zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9use arrayvec::ArrayVec;
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction::DataCollectionState;
  13pub use init::*;
  14use license_detection::LicenseDetectionWatcher;
  15use project::git_store::Repository;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use client::{Client, EditPredictionUsage, UserStore};
  20use cloud_llm_client::{
  21    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  22    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
  23    ZED_VERSION_HEADER_NAME,
  24};
  25use collections::{HashMap, HashSet, VecDeque};
  26use futures::AsyncReadExt;
  27use gpui::{
  28    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  29    Subscription, Task, WeakEntity, actions,
  30};
  31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  32use input_excerpt::excerpt_for_cursor_position;
  33use language::{
  34    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  35};
  36use language_model::{LlmApiToken, RefreshLlmTokenListener};
  37use project::{Project, ProjectEntryId, ProjectPath};
  38use release_channel::AppVersion;
  39use settings::WorktreeId;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    mem,
  46    ops::Range,
  47    path::Path,
  48    rc::Rc,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::EditPredictionRating;
  53use thiserror::Error;
  54use util::ResultExt;
  55use uuid::Uuid;
  56use workspace::Workspace;
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  58use worktree::Worktree;
  59
  60const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
  61const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
  62const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
  63const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
  64const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  65const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  66
  67const MAX_CONTEXT_TOKENS: usize = 150;
  68const MAX_REWRITE_TOKENS: usize = 350;
  69const MAX_EVENT_TOKENS: usize = 500;
  70const MAX_DIAGNOSTIC_GROUPS: usize = 10;
  71
  72/// Maximum number of events to track.
  73const MAX_EVENT_COUNT: usize = 16;
  74
  75/// Maximum number of recent files to track.
  76const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
  77
  78/// Minimum number of milliseconds between recent project entries to keep them
  79const MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES: Duration = Duration::from_millis(100);
  80
  81/// Maximum file path length to include in recent files list.
  82const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;
  83
  84/// Maximum number of edit predictions to store for feedback.
  85const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
  86
// Declares the actions of the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  94
  95#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  96pub struct EditPredictionId(Uuid);
  97
  98impl From<EditPredictionId> for gpui::ElementId {
  99    fn from(value: EditPredictionId) -> Self {
 100        gpui::ElementId::Uuid(value.0)
 101    }
 102}
 103
 104impl std::fmt::Display for EditPredictionId {
 105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 106        write!(f, "{}", self.0)
 107    }
 108}
 109
 110struct ZedPredictUpsell;
 111
 112impl Dismissable for ZedPredictUpsell {
 113    const KEY: &'static str = "dismissed-edit-predict-upsell";
 114
 115    fn dismissed() -> bool {
 116        // To make this backwards compatible with older versions of Zed, we
 117        // check if the user has seen the previous Edit Prediction Onboarding
 118        // before, by checking the data collection choice which was written to
 119        // the database once the user clicked on "Accept and Enable"
 120        if KEY_VALUE_STORE
 121            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 122            .log_err()
 123            .is_some_and(|s| s.is_some())
 124        {
 125            return true;
 126        }
 127
 128        KEY_VALUE_STORE
 129            .read_kvp(Self::KEY)
 130            .log_err()
 131            .is_some_and(|s| s.is_some())
 132    }
 133}
 134
 135pub fn should_show_upsell_modal() -> bool {
 136    !ZedPredictUpsell::dismissed()
 137}
 138
/// Newtype that stores the shared [`Zeta`] entity as a gpui global; `Zeta::register` sets it
/// and `Zeta::global` reads it back.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 143
/// A single edit prediction, together with the request inputs and timing that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    // NOTE(review): presumably the buffer's full path ("untitled" for file-less buffers) as sent
    // with the request — confirm in `process_completion_response`.
    path: Arc<Path>,
    excerpt_range: Range<usize>,
    cursor_offset: usize,
    /// Predicted edits; their anchors are resolved against `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were computed against; `interpolate` rebases from it.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
    // Prompt inputs and the raw model output. NOTE(review): presumably retained so the
    // prediction can be inspected/rated later — confirm.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request; the start point of `latency()`.
    buffer_snapshotted_at: Instant,
    /// When the response was received; the end point of `latency()`.
    response_received_at: Instant,
}
 160
 161impl EditPrediction {
 162    fn latency(&self) -> Duration {
 163        self.response_received_at
 164            .duration_since(self.buffer_snapshotted_at)
 165    }
 166
 167    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 168        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 169    }
 170}
 171
 172fn interpolate(
 173    old_snapshot: &BufferSnapshot,
 174    new_snapshot: &BufferSnapshot,
 175    current_edits: Arc<[(Range<Anchor>, String)]>,
 176) -> Option<Vec<(Range<Anchor>, String)>> {
 177    let mut edits = Vec::new();
 178
 179    let mut model_edits = current_edits.iter().peekable();
 180    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 181        while let Some((model_old_range, _)) = model_edits.peek() {
 182            let model_old_range = model_old_range.to_offset(old_snapshot);
 183            if model_old_range.end < user_edit.old.start {
 184                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 185                edits.push((model_old_range.clone(), model_new_text.clone()));
 186            } else {
 187                break;
 188            }
 189        }
 190
 191        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 192            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 193            if user_edit.old == model_old_offset_range {
 194                let user_new_text = new_snapshot
 195                    .text_for_range(user_edit.new.clone())
 196                    .collect::<String>();
 197
 198                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 199                    if !model_suffix.is_empty() {
 200                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 201                        edits.push((anchor..anchor, model_suffix.to_string()));
 202                    }
 203
 204                    model_edits.next();
 205                    continue;
 206                }
 207            }
 208        }
 209
 210        return None;
 211    }
 212
 213    edits.extend(model_edits.cloned());
 214
 215    if edits.is_empty() { None } else { Some(edits) }
 216}
 217
 218impl std::fmt::Debug for EditPrediction {
 219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 220        f.debug_struct("EditPrediction")
 221            .field("id", &self.id)
 222            .field("path", &self.path)
 223            .field("edits", &self.edits)
 224            .finish_non_exhaustive()
 225    }
 226}
 227
/// Edit prediction state; `Zeta::register` keeps a single instance as a gpui global.
pub struct Zeta {
    /// Workspace used for notifications; an invalid handle when created without a workspace.
    workspace: WeakEntity<Workspace>,
    client: Arc<Client>,
    /// Recent buffer-change events sent with prediction requests; `push_event` coalesces
    /// consecutive changes and bounds the length by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Tracked buffers keyed by entity id; an entry is removed when its buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Edit predictions kept for feedback (capacity `MAX_SHOWN_COMPLETION_COUNT`).
    shown_completions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has already rated — confirm.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token used to authorize LLM requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Held so the token-refresh subscription stays active.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license detection watcher per worktree passed to `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Recently active project entries, fed by the project's `ActiveEntryChanged` events.
    // NOTE(review): the `Instant` is presumably when each entry was recorded — confirm.
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
 244
 245impl Zeta {
 246    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 247        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 248    }
 249
 250    pub fn register(
 251        workspace: Option<Entity<Workspace>>,
 252        worktree: Option<Entity<Worktree>>,
 253        client: Arc<Client>,
 254        user_store: Entity<UserStore>,
 255        cx: &mut App,
 256    ) -> Entity<Self> {
 257        let this = Self::global(cx).unwrap_or_else(|| {
 258            let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
 259            cx.set_global(ZetaGlobal(entity.clone()));
 260            entity
 261        });
 262
 263        this.update(cx, move |this, cx| {
 264            if let Some(worktree) = worktree {
 265                let worktree_id = worktree.read(cx).id();
 266                this.license_detection_watchers
 267                    .entry(worktree_id)
 268                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 269            }
 270        });
 271
 272        this
 273    }
 274
 275    pub fn clear_history(&mut self) {
 276        self.events.clear();
 277    }
 278
 279    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 280        self.user_store.read(cx).edit_prediction_usage()
 281    }
 282
 283    fn new(
 284        workspace: Option<Entity<Workspace>>,
 285        client: Arc<Client>,
 286        user_store: Entity<UserStore>,
 287        cx: &mut Context<Self>,
 288    ) -> Self {
 289        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 290
 291        let data_collection_choice = Self::load_data_collection_choices();
 292        let data_collection_choice = cx.new(|_| data_collection_choice);
 293
 294        if let Some(workspace) = &workspace {
 295            cx.subscribe(
 296                &workspace.read(cx).project().clone(),
 297                |this, _workspace, event, _cx| match event {
 298                    project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
 299                        this.push_recent_project_entry(*project_entry_id)
 300                    }
 301                    _ => {}
 302                },
 303            )
 304            .detach();
 305        }
 306
 307        Self {
 308            workspace: workspace.map_or_else(
 309                || WeakEntity::new_invalid(),
 310                |workspace| workspace.downgrade(),
 311            ),
 312            client,
 313            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 314            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
 315            rated_completions: HashSet::default(),
 316            registered_buffers: HashMap::default(),
 317            data_collection_choice,
 318            llm_token: LlmApiToken::default(),
 319            _llm_token_subscription: cx.subscribe(
 320                &refresh_llm_token_listener,
 321                |this, _listener, _event, cx| {
 322                    let client = this.client.clone();
 323                    let llm_token = this.llm_token.clone();
 324                    cx.spawn(async move |_this, _cx| {
 325                        llm_token.refresh(&client).await?;
 326                        anyhow::Ok(())
 327                    })
 328                    .detach_and_log_err(cx);
 329                },
 330            ),
 331            update_required: false,
 332            license_detection_watchers: HashMap::default(),
 333            user_store,
 334            recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
 335        }
 336    }
 337
 338    fn push_event(&mut self, event: Event) {
 339        if let Some(Event::BufferChange {
 340            new_snapshot: last_new_snapshot,
 341            timestamp: last_timestamp,
 342            ..
 343        }) = self.events.back_mut()
 344        {
 345            // Coalesce edits for the same buffer when they happen one after the other.
 346            let Event::BufferChange {
 347                old_snapshot,
 348                new_snapshot,
 349                timestamp,
 350            } = &event;
 351
 352            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 353                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 354                && old_snapshot.version == last_new_snapshot.version
 355            {
 356                *last_new_snapshot = new_snapshot.clone();
 357                *last_timestamp = *timestamp;
 358                return;
 359            }
 360        }
 361
 362        if self.events.len() >= MAX_EVENT_COUNT {
 363            // These are halved instead of popping to improve prompt caching.
 364            self.events.drain(..MAX_EVENT_COUNT / 2);
 365        }
 366
 367        self.events.push_back(event);
 368    }
 369
 370    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 371        let buffer_id = buffer.entity_id();
 372        let weak_buffer = buffer.downgrade();
 373
 374        if let std::collections::hash_map::Entry::Vacant(entry) =
 375            self.registered_buffers.entry(buffer_id)
 376        {
 377            let snapshot = buffer.read(cx).snapshot();
 378
 379            entry.insert(RegisteredBuffer {
 380                snapshot,
 381                _subscriptions: [
 382                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 383                        this.handle_buffer_event(buffer, event, cx);
 384                    }),
 385                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 386                        this.registered_buffers.remove(&weak_buffer.entity_id());
 387                    }),
 388                ],
 389            });
 390        };
 391    }
 392
 393    fn handle_buffer_event(
 394        &mut self,
 395        buffer: Entity<Buffer>,
 396        event: &language::BufferEvent,
 397        cx: &mut Context<Self>,
 398    ) {
 399        if let language::BufferEvent::Edited = event {
 400            self.report_changes_for_buffer(&buffer, cx);
 401        }
 402    }
 403
    /// Shared implementation behind `request_completion` and the test-only `fake_completion`.
    ///
    /// Snapshots `buffer`, gathers prompt context around `cursor` (the recent events and, when
    /// `can_collect_data` allows it, git info), passes the assembled request body to
    /// `perform_predict_edits`, and converts the response into an [`EditPrediction`] via
    /// `process_completion_response`.
    ///
    /// `perform_predict_edits` performs the actual request; taking it as a parameter lets
    /// callers substitute a canned response.
    ///
    /// # Errors
    ///
    /// Propagates errors from context gathering and from `perform_predict_edits`. When the
    /// latter fails with `ZedUpdateRequiredError`, `update_required` is set and, if a
    /// `workspace` was given, an "Update Zed" notification is shown before the error is
    /// returned.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        let buffer_snapshotted_at = Instant::now();
        // NOTE(review): this runs before `self.events` is cloned below, presumably so that any
        // change it reports is already part of this request's event history — confirm.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Git info is only gathered when data collection is allowed for this request.
        let git_info = if matches!(can_collect_data, CanCollectData(true)) {
            self.gather_git_info(
                cursor_point.clone(),
                &buffer_snapshotted_at,
                &snapshot,
                project.clone(),
                cx,
            )
        } else {
            None
        };

        // Buffers without a file are reported under the placeholder path "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        // Passed as a closure, so `gather_context` decides when the events prompt is rendered.
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the prompt inputs; `body` itself is moved into the request below.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // The server requires a newer Zed: remember that and, if possible, tell
                    // the user how to update before propagating the error.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage reported alongside the response is forwarded to the user store.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // Record latency telemetry for roughly 1% of requests (3 of the 256 `u8` values).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 564
    /// Generates several example completions of various states to fill the Zeta completion
    /// modal.
    ///
    /// Every completion targets the same in-memory test buffer at the start of its second line
    /// and is produced through `fake_completion`, so no network request is made.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Row 1, column 0: the start of the buffer's second line.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            for task in completion_tasks {
                task.await.unwrap();
            }

            // NOTE(review): assumes each finished task appended its prediction to
            // `shown_completions` in order. The third and fourth predictions then get empty
            // edit lists, presumably so the modal also shows predictions without edits.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 717
 718    #[cfg(any(test, feature = "test-support"))]
 719    pub fn fake_completion(
 720        &mut self,
 721        project: Option<&Entity<Project>>,
 722        buffer: &Entity<Buffer>,
 723        position: language::Anchor,
 724        response: PredictEditsResponse,
 725        cx: &mut Context<Self>,
 726    ) -> Task<Result<Option<EditPrediction>>> {
 727        use std::future::ready;
 728
 729        self.request_completion_impl(
 730            None,
 731            project,
 732            buffer,
 733            position,
 734            CanCollectData(false),
 735            cx,
 736            |_params| ready(Ok((response, None))),
 737        )
 738    }
 739
 740    pub fn request_completion(
 741        &mut self,
 742        project: Option<&Entity<Project>>,
 743        buffer: &Entity<Buffer>,
 744        position: language::Anchor,
 745        can_collect_data: CanCollectData,
 746        cx: &mut Context<Self>,
 747    ) -> Task<Result<Option<EditPrediction>>> {
 748        self.request_completion_impl(
 749            self.workspace.upgrade(),
 750            project,
 751            buffer,
 752            position,
 753            can_collect_data,
 754            cx,
 755            Self::perform_predict_edits,
 756        )
 757    }
 758
 759    pub fn perform_predict_edits(
 760        params: PerformPredictEditsParams,
 761    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 762        async move {
 763            let PerformPredictEditsParams {
 764                client,
 765                llm_token,
 766                app_version,
 767                body,
 768                ..
 769            } = params;
 770
 771            let http_client = client.http_client();
 772            let mut token = llm_token.acquire(&client).await?;
 773            let mut did_retry = false;
 774
 775            loop {
 776                let request_builder = http_client::Request::builder().method(Method::POST);
 777                let request_builder =
 778                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 779                        request_builder.uri(predict_edits_url)
 780                    } else {
 781                        request_builder.uri(
 782                            http_client
 783                                .build_zed_llm_url("/predict_edits/v2", &[])?
 784                                .as_ref(),
 785                        )
 786                    };
 787                let request = request_builder
 788                    .header("Content-Type", "application/json")
 789                    .header("Authorization", format!("Bearer {}", token))
 790                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 791                    .body(serde_json::to_string(&body)?.into())?;
 792
 793                let mut response = http_client.send(request).await?;
 794
 795                if let Some(minimum_required_version) = response
 796                    .headers()
 797                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 798                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 799                {
 800                    anyhow::ensure!(
 801                        app_version >= minimum_required_version,
 802                        ZedUpdateRequiredError {
 803                            minimum_version: minimum_required_version
 804                        }
 805                    );
 806                }
 807
 808                if response.status().is_success() {
 809                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 810
 811                    let mut body = String::new();
 812                    response.body_mut().read_to_string(&mut body).await?;
 813                    return Ok((serde_json::from_str(&body)?, usage));
 814                } else if !did_retry
 815                    && response
 816                        .headers()
 817                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 818                        .is_some()
 819                {
 820                    did_retry = true;
 821                    token = llm_token.refresh(&client).await?;
 822                } else {
 823                    let mut body = String::new();
 824                    response.body_mut().read_to_string(&mut body).await?;
 825                    anyhow::bail!(
 826                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 827                        response.status(),
 828                        body
 829                    );
 830                }
 831            }
 832        }
 833    }
 834
 835    fn accept_edit_prediction(
 836        &mut self,
 837        request_id: EditPredictionId,
 838        cx: &mut Context<Self>,
 839    ) -> Task<Result<()>> {
 840        let client = self.client.clone();
 841        let llm_token = self.llm_token.clone();
 842        let app_version = AppVersion::global(cx);
 843        cx.spawn(async move |this, cx| {
 844            let http_client = client.http_client();
 845            let mut response = llm_token_retry(&llm_token, &client, |token| {
 846                let request_builder = http_client::Request::builder().method(Method::POST);
 847                let request_builder =
 848                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 849                        request_builder.uri(accept_prediction_url)
 850                    } else {
 851                        request_builder.uri(
 852                            http_client
 853                                .build_zed_llm_url("/predict_edits/accept", &[])?
 854                                .as_ref(),
 855                        )
 856                    };
 857                Ok(request_builder
 858                    .header("Content-Type", "application/json")
 859                    .header("Authorization", format!("Bearer {}", token))
 860                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 861                    .body(
 862                        serde_json::to_string(&AcceptEditPredictionBody {
 863                            request_id: request_id.0,
 864                        })?
 865                        .into(),
 866                    )?)
 867            })
 868            .await?;
 869
 870            if let Some(minimum_required_version) = response
 871                .headers()
 872                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 873                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 874                && app_version < minimum_required_version
 875            {
 876                return Err(anyhow!(ZedUpdateRequiredError {
 877                    minimum_version: minimum_required_version
 878                }));
 879            }
 880
 881            if response.status().is_success() {
 882                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 883                    this.update(cx, |this, cx| {
 884                        this.user_store.update(cx, |user_store, cx| {
 885                            user_store.update_edit_prediction_usage(usage, cx);
 886                        });
 887                    })?;
 888                }
 889
 890                Ok(())
 891            } else {
 892                let mut body = String::new();
 893                response.body_mut().read_to_string(&mut body).await?;
 894                Err(anyhow!(
 895                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 896                    response.status(),
 897                    body
 898                ))
 899            }
 900        })
 901    }
 902
    /// Converts a raw `PredictEditsResponse` into an [`EditPrediction`] for `buffer`.
    ///
    /// The response's output excerpt is parsed into edits (via `Self::parse_edits`) on a
    /// background task, relative to the `snapshot` passed in. Those edits are then interpolated
    /// onto the buffer's current snapshot, and an edit preview is computed for them.
    ///
    /// Resolves to `Ok(None)` when `interpolate` returns `None` for the current snapshot.
    /// Errors from parsing and from `read_with` on the buffer are propagated.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse off the main thread. The excerpt, range and snapshot are cloned into the
            // background task because they are used again below.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Carry the edits from the original snapshot over to the buffer's latest snapshot;
            // if `interpolate` yields `None`, the closure returns `None` and we resolve to
            // `Ok(None)`.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 964
 965    fn parse_edits(
 966        output_excerpt: Arc<str>,
 967        editable_range: Range<usize>,
 968        snapshot: &BufferSnapshot,
 969    ) -> Result<Vec<(Range<Anchor>, String)>> {
 970        let content = output_excerpt.replace(CURSOR_MARKER, "");
 971
 972        let start_markers = content
 973            .match_indices(EDITABLE_REGION_START_MARKER)
 974            .collect::<Vec<_>>();
 975        anyhow::ensure!(
 976            start_markers.len() == 1,
 977            "expected exactly one start marker, found {}",
 978            start_markers.len()
 979        );
 980
 981        let end_markers = content
 982            .match_indices(EDITABLE_REGION_END_MARKER)
 983            .collect::<Vec<_>>();
 984        anyhow::ensure!(
 985            end_markers.len() == 1,
 986            "expected exactly one end marker, found {}",
 987            end_markers.len()
 988        );
 989
 990        let sof_markers = content
 991            .match_indices(START_OF_FILE_MARKER)
 992            .collect::<Vec<_>>();
 993        anyhow::ensure!(
 994            sof_markers.len() <= 1,
 995            "expected at most one start-of-file marker, found {}",
 996            sof_markers.len()
 997        );
 998
 999        let codefence_start = start_markers[0].0;
1000        let content = &content[codefence_start..];
1001
1002        let newline_ix = content.find('\n').context("could not find newline")?;
1003        let content = &content[newline_ix + 1..];
1004
1005        let codefence_end = content
1006            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1007            .context("could not find end marker")?;
1008        let new_text = &content[..codefence_end];
1009
1010        let old_text = snapshot
1011            .text_for_range(editable_range.clone())
1012            .collect::<String>();
1013
1014        Ok(Self::compute_edits(
1015            old_text,
1016            new_text,
1017            editable_range.start,
1018            snapshot,
1019        ))
1020    }
1021
1022    pub fn compute_edits(
1023        old_text: String,
1024        new_text: &str,
1025        offset: usize,
1026        snapshot: &BufferSnapshot,
1027    ) -> Vec<(Range<Anchor>, String)> {
1028        text_diff(&old_text, new_text)
1029            .into_iter()
1030            .map(|(mut old_range, new_text)| {
1031                old_range.start += offset;
1032                old_range.end += offset;
1033
1034                let prefix_len = common_prefix(
1035                    snapshot.chars_for_range(old_range.clone()),
1036                    new_text.chars(),
1037                );
1038                old_range.start += prefix_len;
1039
1040                let suffix_len = common_prefix(
1041                    snapshot.reversed_chars_for_range(old_range.clone()),
1042                    new_text[prefix_len..].chars().rev(),
1043                );
1044                old_range.end = old_range.end.saturating_sub(suffix_len);
1045
1046                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1047                let range = if old_range.is_empty() {
1048                    let anchor = snapshot.anchor_after(old_range.start);
1049                    anchor..anchor
1050                } else {
1051                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1052                };
1053                (range, new_text)
1054            })
1055            .collect()
1056    }
1057
    /// Returns whether `completion_id` was rated via `rate_completion` and has not since been
    /// evicted from the shown-completions history by `completion_shown`.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
1061
1062    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1063        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1064            let completion = self.shown_completions.pop_back().unwrap();
1065            self.rated_completions.remove(&completion.id);
1066        }
1067        self.shown_completions.push_front(completion.clone());
1068        cx.notify();
1069    }
1070
    /// Records a user rating for `completion` and reports it via telemetry.
    ///
    /// Marks the completion as rated, then emits an "Edit Prediction Rated" event that carries
    /// the rating, the feedback text, and the completion's input/output excerpts. It then starts
    /// a telemetry flush without waiting for it (the task is detached) and notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        // The macro derives event field names from these arguments; keep them as-is.
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1091
    /// Iterates over the shown-completions history, most recently shown first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
1095
    /// Number of entries in the shown-completions history.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1099
1100    fn report_changes_for_buffer(
1101        &mut self,
1102        buffer: &Entity<Buffer>,
1103        cx: &mut Context<Self>,
1104    ) -> BufferSnapshot {
1105        self.register_buffer(buffer, cx);
1106
1107        let registered_buffer = self
1108            .registered_buffers
1109            .get_mut(&buffer.entity_id())
1110            .unwrap();
1111        let new_snapshot = buffer.read(cx).snapshot();
1112
1113        if new_snapshot.version != registered_buffer.snapshot.version {
1114            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1115            self.push_event(Event::BufferChange {
1116                old_snapshot,
1117                new_snapshot: new_snapshot.clone(),
1118                timestamp: Instant::now(),
1119            });
1120        }
1121
1122        new_snapshot
1123    }
1124
1125    fn load_data_collection_choices() -> DataCollectionChoice {
1126        let choice = KEY_VALUE_STORE
1127            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1128            .log_err()
1129            .flatten();
1130
1131        match choice.as_deref() {
1132            Some("true") => DataCollectionChoice::Enabled,
1133            Some("false") => DataCollectionChoice::Disabled,
1134            Some(_) => {
1135                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1136                DataCollectionChoice::NotAnswered
1137            }
1138            None => DataCollectionChoice::NotAnswered,
1139        }
1140    }
1141
1142    fn gather_git_info(
1143        &mut self,
1144        cursor_point: language::Point,
1145        buffer_snapshotted_at: &Instant,
1146        snapshot: &BufferSnapshot,
1147        project: Option<&Entity<Project>>,
1148        cx: &Context<Self>,
1149    ) -> Option<PredictEditsGitInfo> {
1150        let project = project?.read(cx);
1151        let file = snapshot.file()?;
1152        let project_path = ProjectPath::from_file(file.as_ref(), cx);
1153        let entry = project.entry_for_path(&project_path, cx)?;
1154        if !worktree_entry_eligible_for_collection(&entry) {
1155            return None;
1156        }
1157
1158        let git_store = project.git_store().read(cx);
1159        let (repository, repo_path) =
1160            git_store.repository_and_path_for_project_path(&project_path, cx)?;
1161        let repo_path_str = repo_path.to_str()?;
1162
1163        let repository = repository.read(cx);
1164        let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1165        let remote_origin_url = repository.remote_origin_url.clone();
1166        let remote_upstream_url = repository.remote_upstream_url.clone();
1167        let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1168
1169        Some(PredictEditsGitInfo {
1170            input_path: Some(repo_path_str.to_string()),
1171            cursor_point: Some(cloud_llm_client::Point {
1172                row: cursor_point.row,
1173                column: cursor_point.column,
1174            }),
1175            cursor_offset: Some(cursor_offset),
1176            head_sha: Some(head_sha),
1177            remote_origin_url,
1178            remote_upstream_url,
1179            recent_files: Some(recent_files),
1180        })
1181    }
1182
1183    fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1184        let now = Instant::now();
1185        if let Some(existing_ix) = self
1186            .recent_project_entries
1187            .iter()
1188            .rposition(|(id, _)| *id == project_entry_id)
1189        {
1190            self.recent_project_entries.remove(existing_ix);
1191        }
1192        // filter out rapid changes in active item, particularly since this can happen rapidly when
1193        // a workspace is loaded.
1194        if let Some(most_recent) = self.recent_project_entries.back_mut()
1195            && now.duration_since(most_recent.1) > MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES
1196        {
1197            most_recent.0 = project_entry_id;
1198            most_recent.1 = now;
1199            return;
1200        }
1201        if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1202            self.recent_project_entries.pop_front();
1203        }
1204        self.recent_project_entries
1205            .push_back((project_entry_id, now));
1206    }
1207
    /// Returns the recently active files that belong to `repository`, most recent first.
    ///
    /// Walks `recent_project_entries` from newest to oldest. Each returned path is relative to
    /// the repository root. `active_to_now_ms` is the time in milliseconds from the entry's last
    /// activation to `now`.
    ///
    /// As a side effect, an entry is pruned from `recent_project_entries` when:
    /// - it no longer resolves to an eligible worktree entry,
    /// - its repo path is not valid UTF-8,
    /// - its repo path is longer than `MAX_RECENT_FILE_PATH_LENGTH`, or
    /// - the elapsed milliseconds do not fit `active_to_now_ms`.
    ///
    /// Entries that fall outside `repository` are kept. Returns an empty list if reading the
    /// workspace fails (presumably because it was released — confirm).
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &Context<Self>,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::new();
        // Iterate indices in reverse: `remove(ix)` only shifts later elements, which have
        // already been visited, so the remaining indices stay valid.
        for ix in (0..self.recent_project_entries.len()).rev() {
            let (entry_id, last_active_at) = &self.recent_project_entries[ix];
            if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
                && let worktree = worktree.read(cx)
                && let Some(entry) = worktree.entry_for_id(*entry_id)
                && worktree_entry_eligible_for_collection(entry)
            {
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: entry.path.clone(),
                };
                let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
                else {
                    // entry not removed since queries involving other repositories might occur later
                    continue;
                };
                let Some(repo_path_str) = repo_path.to_str() else {
                    // paths may not be valid UTF-8
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                    self.recent_project_entries.remove(ix);
                    continue;
                }
                let Ok(active_to_now_ms) =
                    now.duration_since(*last_active_at).as_millis().try_into()
                else {
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                results.push(PredictEditsRecentFile {
                    path: repo_path_str.to_string(),
                    active_to_now_ms,
                });
            } else {
                self.recent_project_entries.remove(ix);
            }
        }
        results
    }
1262}
1263
1264fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1265    entry.is_file()
1266        && entry.is_created()
1267        && !entry.is_ignored
1268        && !entry.is_private
1269        && !entry.is_external
1270        && !entry.is_fifo
1271}
1272
/// Inputs for sending a predict-edits request.
pub struct PerformPredictEditsParams {
    /// Client used to send the request.
    pub client: Arc<Client>,
    /// Token used to authenticate with the LLM service.
    pub llm_token: LlmApiToken,
    /// Version of this Zed build. Presumably sent with the request and compared against the
    /// server's minimum required version — confirm in the consumer.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1279
/// Error returned when the server requires a newer Zed version than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version the server reported in its response headers.
    minimum_version: SemanticVersion,
}
1287
/// Returns the length in UTF-8 bytes (measured on `a`'s characters) of the longest run of
/// leading characters that `a` and `b` have in common.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1294
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// Request body built from the excerpt, events, diagnostics and git info.
    pub body: PredictEditsBody,
    /// Buffer offsets of the excerpt's editable region.
    pub editable_range: Range<usize>,
}
1299
1300pub fn gather_context(
1301    project: Option<&Entity<Project>>,
1302    full_path_str: String,
1303    snapshot: &BufferSnapshot,
1304    cursor_point: language::Point,
1305    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1306    can_collect_data: CanCollectData,
1307    git_info: Option<PredictEditsGitInfo>,
1308    cx: &App,
1309) -> Task<Result<GatherContextOutput>> {
1310    let local_lsp_store =
1311        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1312    let diagnostic_groups: Vec<(String, serde_json::Value)> =
1313        if matches!(can_collect_data, CanCollectData(true))
1314            && let Some(local_lsp_store) = local_lsp_store
1315        {
1316            snapshot
1317                .diagnostic_groups(None)
1318                .into_iter()
1319                .filter_map(|(language_server_id, diagnostic_group)| {
1320                    let language_server =
1321                        local_lsp_store.running_language_server_for_id(language_server_id)?;
1322                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1323                    let language_server_name = language_server.name().to_string();
1324                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
1325                    Some((language_server_name, serialized))
1326                })
1327                .collect::<Vec<_>>()
1328        } else {
1329            Vec::new()
1330        };
1331
1332    cx.background_spawn({
1333        let snapshot = snapshot.clone();
1334        async move {
1335            let diagnostic_groups = if diagnostic_groups.is_empty()
1336                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1337            {
1338                None
1339            } else {
1340                Some(diagnostic_groups)
1341            };
1342
1343            let input_excerpt = excerpt_for_cursor_position(
1344                cursor_point,
1345                &full_path_str,
1346                &snapshot,
1347                MAX_REWRITE_TOKENS,
1348                MAX_CONTEXT_TOKENS,
1349            );
1350            let input_events = make_events_prompt();
1351            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1352
1353            let body = PredictEditsBody {
1354                input_events,
1355                input_excerpt: input_excerpt.prompt,
1356                can_collect_data: can_collect_data.0,
1357                diagnostic_groups,
1358                git_info,
1359                outline: None,
1360                speculated_output: None,
1361            };
1362
1363            Ok(GatherContextOutput {
1364                body,
1365                editable_range,
1366            })
1367        }
1368    })
1369}
1370
1371fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1372    let mut result = String::new();
1373    for event in events.iter().rev() {
1374        let event_string = event.to_prompt();
1375        let event_tokens = tokens_for_bytes(event_string.len());
1376        if event_tokens > remaining_tokens {
1377            break;
1378        }
1379
1380        if !result.is_empty() {
1381            result.insert_str(0, "\n\n");
1382        }
1383        result.insert_str(0, &event_string);
1384        remaining_tokens -= event_tokens;
1385    }
1386    result
1387}
1388
/// State kept for each buffer registered with `Zeta`.
struct RegisteredBuffer {
    /// Last snapshot recorded for the buffer; `report_changes_for_buffer` compares against it.
    snapshot: BufferSnapshot,
    /// Held only to keep these subscriptions alive.
    _subscriptions: [gpui::Subscription; 2],
}
1393
/// A user action recorded as context for edit-prediction prompts.
#[derive(Clone)]
pub enum Event {
    /// The buffer changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1402
1403impl Event {
1404    fn to_prompt(&self) -> String {
1405        match self {
1406            Event::BufferChange {
1407                old_snapshot,
1408                new_snapshot,
1409                ..
1410            } => {
1411                let mut prompt = String::new();
1412
1413                let old_path = old_snapshot
1414                    .file()
1415                    .map(|f| f.path().as_ref())
1416                    .unwrap_or(Path::new("untitled"));
1417                let new_path = new_snapshot
1418                    .file()
1419                    .map(|f| f.path().as_ref())
1420                    .unwrap_or(Path::new("untitled"));
1421                if old_path != new_path {
1422                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1423                }
1424
1425                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1426                if !diff.is_empty() {
1427                    write!(
1428                        prompt,
1429                        "User edited {:?}:\n```diff\n{}\n```",
1430                        new_path, diff
1431                    )
1432                    .unwrap();
1433                }
1434
1435                prompt
1436            }
1437        }
1438    }
1439}
1440
/// A prediction held by the provider, together with the buffer it was made for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction applies to.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1446
1447impl CurrentEditPrediction {
1448    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1449        if self.buffer_id != old_completion.buffer_id {
1450            return true;
1451        }
1452
1453        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1454            return true;
1455        };
1456        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1457            return false;
1458        };
1459
1460        if old_edits.len() == 1 && new_edits.len() == 1 {
1461            let (old_range, old_text) = &old_edits[0];
1462            let (new_range, new_text) = &new_edits[0];
1463            new_range == old_range && new_text.starts_with(old_text)
1464        } else {
1465            true
1466        }
1467    }
1468}
1469
/// A prediction request that is still in flight.
struct PendingCompletion {
    /// Sequence number taken from `next_pending_completion_id` when the request started.
    id: usize,
    /// Keeps the request task alive; dropping it presumably cancels the request — confirm.
    _task: Task<()>,
}
1474
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been made (or the stored value was missing or unrecognized).
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// The explicit answer behind this choice: `Some(true)` for `Enabled`, `Some(false)` for
    /// `Disabled`, and `None` while unanswered.
    fn explicit_answer(self) -> Option<bool> {
        match self {
            Self::Enabled => Some(true),
            Self::Disabled => Some(false),
            Self::NotAnswered => None,
        }
    }

    /// Returns `true` only when the user explicitly enabled data collection.
    pub fn is_enabled(self) -> bool {
        self.explicit_answer() == Some(true)
    }

    /// Returns `true` once the user has either enabled or disabled data collection.
    pub fn is_answered(self) -> bool {
        self.explicit_answer().is_some()
    }

    /// Returns the opposite choice. An unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        let currently_enabled = self.explicit_answer().unwrap_or(false);
        DataCollectionChoice::from(!currently_enabled)
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1514
/// Data-collection state for the buffer an edit-prediction provider serves.
pub struct ProviderDataCollection {
    /// `None` when data collection is not possible for the provider's buffer.
    choice: Option<Entity<DataCollectionChoice>>,
    /// Decides whether the project counts as open source; `None` in the same cases as `choice`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1520
/// Whether data may be collected for a request.
///
/// As produced by `ProviderDataCollection::can_collect_data`, it is true only when collection is
/// enabled and the project is detected as open source.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1523
1524impl ProviderDataCollection {
1525    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1526        let choice_and_watcher = buffer.and_then(|buffer| {
1527            let file = buffer.read(cx).file()?;
1528
1529            if !file.is_local() || file.is_private() {
1530                return None;
1531            }
1532
1533            let zeta = zeta.read(cx);
1534            let choice = zeta.data_collection_choice.clone();
1535
1536            let license_detection_watcher = zeta
1537                .license_detection_watchers
1538                .get(&file.worktree_id(cx))
1539                .cloned()?;
1540
1541            Some((choice, license_detection_watcher))
1542        });
1543
1544        if let Some((choice, watcher)) = choice_and_watcher {
1545            ProviderDataCollection {
1546                choice: Some(choice),
1547                license_detection_watcher: Some(watcher),
1548            }
1549        } else {
1550            ProviderDataCollection {
1551                choice: None,
1552                license_detection_watcher: None,
1553            }
1554        }
1555    }
1556
1557    pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1558        CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1559    }
1560
1561    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1562        self.choice
1563            .as_ref()
1564            .is_some_and(|choice| choice.read(cx).is_enabled())
1565    }
1566
1567    fn is_project_open_source(&self) -> bool {
1568        self.license_detection_watcher
1569            .as_ref()
1570            .is_some_and(|watcher| watcher.is_project_open_source())
1571    }
1572
1573    pub fn toggle(&mut self, cx: &mut App) {
1574        if let Some(choice) = self.choice.as_mut() {
1575            let new_choice = choice.update(cx, |choice, _cx| {
1576                let new_choice = choice.toggle();
1577                *choice = new_choice;
1578                new_choice
1579            });
1580
1581            db::write_and_log(cx, move || {
1582                KEY_VALUE_STORE.write_kvp(
1583                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1584                    new_choice.is_enabled().to_string(),
1585                )
1586            });
1587        }
1588    }
1589}
1590
1591async fn llm_token_retry(
1592    llm_token: &LlmApiToken,
1593    client: &Arc<Client>,
1594    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1595) -> Result<Response<AsyncBody>> {
1596    let mut did_retry = false;
1597    let http_client = client.http_client();
1598    let mut token = llm_token.acquire(client).await?;
1599    loop {
1600        let request = build_request(token.clone())?;
1601        let response = http_client.send(request).await?;
1602
1603        if !did_retry
1604            && !response.status().is_success()
1605            && response
1606                .headers()
1607                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1608                .is_some()
1609        {
1610            did_retry = true;
1611            token = llm_token.refresh(client).await?;
1612            continue;
1613        }
1614
1615        return Ok(response);
1616    }
1617}
1618
/// Edit-prediction provider backed by the shared [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight prediction requests; at most two are held.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id for the next pending completion; `refresh` increments it on each request.
    next_pending_completion_id: usize,
    /// The current prediction, if any. `refresh` returns early while it still interpolates onto
    /// the buffer.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider. Its inner `choice` is `None` when data
    /// collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// Time of the last prediction request. Presumably compared against `THROTTLE_TIMEOUT` in
    /// `refresh` — confirm.
    last_request_timestamp: Instant,
}
1628
1629impl ZetaEditPredictionProvider {
1630    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1631
1632    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1633        Self {
1634            zeta,
1635            pending_completions: ArrayVec::new(),
1636            next_pending_completion_id: 0,
1637            current_completion: None,
1638            provider_data_collection,
1639            last_request_timestamp: Instant::now(),
1640        }
1641    }
1642}
1643
1644impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Identifier of this provider.
    fn name() -> &'static str {
        "zed-predict"
    }

    /// Human-readable provider name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    /// Always `true`: this provider opts into showing its predictions in the completions menu.
    fn show_completions_in_menu() -> bool {
        true
    }

    /// Always `true`: this provider opts into the tab-accept marker.
    fn show_tab_accept_marker() -> bool {
        true
    }
1660
1661    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1662        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1663
1664        if self.provider_data_collection.is_data_collection_enabled(cx) {
1665            DataCollectionState::Enabled {
1666                is_project_open_source,
1667            }
1668        } else {
1669            DataCollectionState::Disabled {
1670                is_project_open_source,
1671            }
1672        }
1673    }
1674
1675    fn toggle_data_collection(&mut self, cx: &mut App) {
1676        self.provider_data_collection.toggle(cx);
1677    }
1678
1679    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1680        self.zeta.read(cx).usage(cx)
1681    }
1682
1683    fn is_enabled(
1684        &self,
1685        _buffer: &Entity<Buffer>,
1686        _cursor_position: language::Anchor,
1687        _cx: &App,
1688    ) -> bool {
1689        true
1690    }
1691    fn is_refreshing(&self) -> bool {
1692        !self.pending_completions.is_empty()
1693    }
1694
1695    fn refresh(
1696        &mut self,
1697        project: Option<Entity<Project>>,
1698        buffer: Entity<Buffer>,
1699        position: language::Anchor,
1700        _debounce: bool,
1701        cx: &mut Context<Self>,
1702    ) {
1703        if self.zeta.read(cx).update_required {
1704            return;
1705        }
1706
1707        if self
1708            .zeta
1709            .read(cx)
1710            .user_store
1711            .read_with(cx, |user_store, _cx| {
1712                user_store.account_too_young() || user_store.has_overdue_invoices()
1713            })
1714        {
1715            return;
1716        }
1717
1718        if let Some(current_completion) = self.current_completion.as_ref() {
1719            let snapshot = buffer.read(cx).snapshot();
1720            if current_completion
1721                .completion
1722                .interpolate(&snapshot)
1723                .is_some()
1724            {
1725                return;
1726            }
1727        }
1728
1729        let pending_completion_id = self.next_pending_completion_id;
1730        self.next_pending_completion_id += 1;
1731        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1732        let last_request_timestamp = self.last_request_timestamp;
1733
1734        let task = cx.spawn(async move |this, cx| {
1735            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1736                .checked_duration_since(Instant::now())
1737            {
1738                cx.background_executor().timer(timeout).await;
1739            }
1740
1741            let completion_request = this.update(cx, |this, cx| {
1742                this.last_request_timestamp = Instant::now();
1743                this.zeta.update(cx, |zeta, cx| {
1744                    zeta.request_completion(
1745                        project.as_ref(),
1746                        &buffer,
1747                        position,
1748                        can_collect_data,
1749                        cx,
1750                    )
1751                })
1752            });
1753
1754            let completion = match completion_request {
1755                Ok(completion_request) => {
1756                    let completion_request = completion_request.await;
1757                    completion_request.map(|c| {
1758                        c.map(|completion| CurrentEditPrediction {
1759                            buffer_id: buffer.entity_id(),
1760                            completion,
1761                        })
1762                    })
1763                }
1764                Err(error) => Err(error),
1765            };
1766            let Some(new_completion) = completion
1767                .context("edit prediction failed")
1768                .log_err()
1769                .flatten()
1770            else {
1771                this.update(cx, |this, cx| {
1772                    if this.pending_completions[0].id == pending_completion_id {
1773                        this.pending_completions.remove(0);
1774                    } else {
1775                        this.pending_completions.clear();
1776                    }
1777
1778                    cx.notify();
1779                })
1780                .ok();
1781                return;
1782            };
1783
1784            this.update(cx, |this, cx| {
1785                if this.pending_completions[0].id == pending_completion_id {
1786                    this.pending_completions.remove(0);
1787                } else {
1788                    this.pending_completions.clear();
1789                }
1790
1791                if let Some(old_completion) = this.current_completion.as_ref() {
1792                    let snapshot = buffer.read(cx).snapshot();
1793                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1794                        this.zeta.update(cx, |zeta, cx| {
1795                            zeta.completion_shown(&new_completion.completion, cx);
1796                        });
1797                        this.current_completion = Some(new_completion);
1798                    }
1799                } else {
1800                    this.zeta.update(cx, |zeta, cx| {
1801                        zeta.completion_shown(&new_completion.completion, cx);
1802                    });
1803                    this.current_completion = Some(new_completion);
1804                }
1805
1806                cx.notify();
1807            })
1808            .ok();
1809        });
1810
1811        // We always maintain at most two pending completions. When we already
1812        // have two, we replace the newest one.
1813        if self.pending_completions.len() <= 1 {
1814            self.pending_completions.push(PendingCompletion {
1815                id: pending_completion_id,
1816                _task: task,
1817            });
1818        } else if self.pending_completions.len() == 2 {
1819            self.pending_completions.pop();
1820            self.pending_completions.push(PendingCompletion {
1821                id: pending_completion_id,
1822                _task: task,
1823            });
1824        }
1825    }
1826
1827    fn cycle(
1828        &mut self,
1829        _buffer: Entity<Buffer>,
1830        _cursor_position: language::Anchor,
1831        _direction: edit_prediction::Direction,
1832        _cx: &mut Context<Self>,
1833    ) {
1834        // Right now we don't support cycling.
1835    }
1836
1837    fn accept(&mut self, cx: &mut Context<Self>) {
1838        let completion_id = self
1839            .current_completion
1840            .as_ref()
1841            .map(|completion| completion.completion.id);
1842        if let Some(completion_id) = completion_id {
1843            self.zeta
1844                .update(cx, |zeta, cx| {
1845                    zeta.accept_edit_prediction(completion_id, cx)
1846                })
1847                .detach();
1848        }
1849        self.pending_completions.clear();
1850    }
1851
1852    fn discard(&mut self, _cx: &mut Context<Self>) {
1853        self.pending_completions.clear();
1854        self.current_completion.take();
1855    }
1856
1857    fn suggest(
1858        &mut self,
1859        buffer: &Entity<Buffer>,
1860        cursor_position: language::Anchor,
1861        cx: &mut Context<Self>,
1862    ) -> Option<edit_prediction::EditPrediction> {
1863        let CurrentEditPrediction {
1864            buffer_id,
1865            completion,
1866            ..
1867        } = self.current_completion.as_mut()?;
1868
1869        // Invalidate previous completion if it was generated for a different buffer.
1870        if *buffer_id != buffer.entity_id() {
1871            self.current_completion.take();
1872            return None;
1873        }
1874
1875        let buffer = buffer.read(cx);
1876        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1877            self.current_completion.take();
1878            return None;
1879        };
1880
1881        let cursor_row = cursor_position.to_point(buffer).row;
1882        let (closest_edit_ix, (closest_edit_range, _)) =
1883            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1884                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1885                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1886                cmp::min(distance_from_start, distance_from_end)
1887            })?;
1888
1889        let mut edit_start_ix = closest_edit_ix;
1890        for (range, _) in edits[..edit_start_ix].iter().rev() {
1891            let distance_from_closest_edit =
1892                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1893            if distance_from_closest_edit <= 1 {
1894                edit_start_ix -= 1;
1895            } else {
1896                break;
1897            }
1898        }
1899
1900        let mut edit_end_ix = closest_edit_ix + 1;
1901        for (range, _) in &edits[edit_end_ix..] {
1902            let distance_from_closest_edit =
1903                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1904            if distance_from_closest_edit <= 1 {
1905                edit_end_ix += 1;
1906            } else {
1907                break;
1908            }
1909        }
1910
1911        Some(edit_prediction::EditPrediction {
1912            id: Some(completion.id.to_string().into()),
1913            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1914            edit_preview: Some(completion.edit_preview.clone()),
1915        })
1916    }
1917}
1918
/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    // Deliberately low bytes-per-token ratio: it over-counts tokens, so any input limit
    // derived from this estimate errs on the side of being too small rather than too large.
    const ASSUMED_BYTES_PER_TOKEN: usize = 3;
    bytes.div_euclid(ASSUMED_BYTES_PER_TOKEN)
}
1925
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        let (zeta, _server) = zeta_with_fake_prediction(
            Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45").unwrap(),
            completion_response.to_string(),
            cx,
        )
        .await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Builds a `Zeta` whose HTTP client is faked: `/client/llm_tokens` yields a fixed LLM
    /// token, `/predict_edits/v2` answers with `request_id` and `output_excerpt`, and any other
    /// route returns 404.
    ///
    /// The returned `FakeServer` authenticates the client; keep it bound for as long as the
    /// `Zeta` is used (presumably dropping it ends the fake session — not verified here).
    async fn zeta_with_fake_prediction(
        request_id: Uuid,
        output_excerpt: String,
        cx: &mut TestAppContext,
    ) -> (Entity<Zeta>, FakeServer) {
        let http_client = FakeHttpClient::create(move |req| {
            let output_excerpt = output_excerpt.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id,
                                output_excerpt,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
        (zeta, server)
    }

    /// Requests a prediction for `buffer_content` (cursor at row 1, column 0) from a fake
    /// server that returns `completion_response`, and returns the resulting edits as point
    /// ranges in the original buffer.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let (zeta, _server) =
            zeta_with_fake_prediction(Uuid::new_v4(), completion_response.to_string(), cx).await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset-based edits against `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}