zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::{Dismissable, KEY_VALUE_STORE};
  11use edit_prediction::DataCollectionState;
  12pub use init::*;
  13use license_detection::LicenseDetectionWatcher;
  14pub use rate_completion_modal::*;
  15
  16use anyhow::{Context as _, Result, anyhow};
  17use arrayvec::ArrayVec;
  18use client::{Client, EditPredictionUsage, UserStore};
  19use cloud_llm_client::{
  20    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  21    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
  22};
  23use collections::{HashMap, HashSet, VecDeque};
  24use futures::AsyncReadExt;
  25use gpui::{
  26    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  27    SharedString, Subscription, Task, actions,
  28};
  29use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  30use input_excerpt::excerpt_for_cursor_position;
  31use language::{
  32    Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  33};
  34use language_model::{LlmApiToken, RefreshLlmTokenListener};
  35use project::{Project, ProjectPath};
  36use release_channel::AppVersion;
  37use settings::WorktreeId;
  38use std::collections::hash_map;
  39use std::mem;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    ops::Range,
  46    path::Path,
  47    rc::Rc,
  48    sync::Arc,
  49    time::{Duration, Instant},
  50};
  51use telemetry_events::EditPredictionRating;
  52use thiserror::Error;
  53use util::ResultExt;
  54use uuid::Uuid;
  55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  56use worktree::Worktree;
  57
  58const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
  59const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
  60const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
  61const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
  62const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  63const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  64
  65const MAX_CONTEXT_TOKENS: usize = 150;
  66const MAX_REWRITE_TOKENS: usize = 350;
  67const MAX_EVENT_TOKENS: usize = 500;
  68
  69/// Maximum number of events to track.
  70const MAX_EVENT_COUNT: usize = 16;
  71
  72actions!(
  73    edit_prediction,
  74    [
  75        /// Clears the edit prediction history.
  76        ClearHistory
  77    ]
  78);
  79
  80#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  81pub struct EditPredictionId(Uuid);
  82
  83impl From<EditPredictionId> for gpui::ElementId {
  84    fn from(value: EditPredictionId) -> Self {
  85        gpui::ElementId::Uuid(value.0)
  86    }
  87}
  88
  89impl std::fmt::Display for EditPredictionId {
  90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  91        write!(f, "{}", self.0)
  92    }
  93}
  94
  95struct ZedPredictUpsell;
  96
  97impl Dismissable for ZedPredictUpsell {
  98    const KEY: &'static str = "dismissed-edit-predict-upsell";
  99
 100    fn dismissed() -> bool {
 101        // To make this backwards compatible with older versions of Zed, we
 102        // check if the user has seen the previous Edit Prediction Onboarding
 103        // before, by checking the data collection choice which was written to
 104        // the database once the user clicked on "Accept and Enable"
 105        if KEY_VALUE_STORE
 106            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 107            .log_err()
 108            .is_some_and(|s| s.is_some())
 109        {
 110            return true;
 111        }
 112
 113        KEY_VALUE_STORE
 114            .read_kvp(Self::KEY)
 115            .log_err()
 116            .is_some_and(|s| s.is_some())
 117    }
 118}
 119
 120pub fn should_show_upsell_modal() -> bool {
 121    !ZedPredictUpsell::dismissed()
 122}
 123
 124#[derive(Clone)]
 125struct ZetaGlobal(Entity<Zeta>);
 126
 127impl Global for ZetaGlobal {}
 128
 129#[derive(Clone)]
 130pub struct EditPrediction {
 131    id: EditPredictionId,
 132    path: Arc<Path>,
 133    excerpt_range: Range<usize>,
 134    cursor_offset: usize,
 135    edits: Arc<[(Range<Anchor>, String)]>,
 136    snapshot: BufferSnapshot,
 137    edit_preview: EditPreview,
 138    input_outline: Arc<str>,
 139    input_events: Arc<str>,
 140    input_excerpt: Arc<str>,
 141    output_excerpt: Arc<str>,
 142    buffer_snapshotted_at: Instant,
 143    response_received_at: Instant,
 144}
 145
 146impl EditPrediction {
 147    fn latency(&self) -> Duration {
 148        self.response_received_at
 149            .duration_since(self.buffer_snapshotted_at)
 150    }
 151
 152    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 153        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 154    }
 155}
 156
 157fn interpolate(
 158    old_snapshot: &BufferSnapshot,
 159    new_snapshot: &BufferSnapshot,
 160    current_edits: Arc<[(Range<Anchor>, String)]>,
 161) -> Option<Vec<(Range<Anchor>, String)>> {
 162    let mut edits = Vec::new();
 163
 164    let mut model_edits = current_edits.iter().peekable();
 165    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 166        while let Some((model_old_range, _)) = model_edits.peek() {
 167            let model_old_range = model_old_range.to_offset(old_snapshot);
 168            if model_old_range.end < user_edit.old.start {
 169                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 170                edits.push((model_old_range.clone(), model_new_text.clone()));
 171            } else {
 172                break;
 173            }
 174        }
 175
 176        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 177            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 178            if user_edit.old == model_old_offset_range {
 179                let user_new_text = new_snapshot
 180                    .text_for_range(user_edit.new.clone())
 181                    .collect::<String>();
 182
 183                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 184                    if !model_suffix.is_empty() {
 185                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 186                        edits.push((anchor..anchor, model_suffix.to_string()));
 187                    }
 188
 189                    model_edits.next();
 190                    continue;
 191                }
 192            }
 193        }
 194
 195        return None;
 196    }
 197
 198    edits.extend(model_edits.cloned());
 199
 200    if edits.is_empty() { None } else { Some(edits) }
 201}
 202
 203impl std::fmt::Debug for EditPrediction {
 204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 205        f.debug_struct("EditPrediction")
 206            .field("id", &self.id)
 207            .field("path", &self.path)
 208            .field("edits", &self.edits)
 209            .finish_non_exhaustive()
 210    }
 211}
 212
 213pub struct Zeta {
 214    projects: HashMap<EntityId, ZetaProject>,
 215    client: Arc<Client>,
 216    shown_completions: VecDeque<EditPrediction>,
 217    rated_completions: HashSet<EditPredictionId>,
 218    data_collection_choice: DataCollectionChoice,
 219    llm_token: LlmApiToken,
 220    _llm_token_subscription: Subscription,
 221    /// Whether an update to a newer version of Zed is required to continue using Zeta.
 222    update_required: bool,
 223    user_store: Entity<UserStore>,
 224    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 225}
 226
 227struct ZetaProject {
 228    events: VecDeque<Event>,
 229    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 230}
 231
 232impl Zeta {
 233    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 234        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 235    }
 236
 237    pub fn register(
 238        worktree: Option<Entity<Worktree>>,
 239        client: Arc<Client>,
 240        user_store: Entity<UserStore>,
 241        cx: &mut App,
 242    ) -> Entity<Self> {
 243        let this = Self::global(cx).unwrap_or_else(|| {
 244            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 245            cx.set_global(ZetaGlobal(entity.clone()));
 246            entity
 247        });
 248
 249        this.update(cx, move |this, cx| {
 250            if let Some(worktree) = worktree {
 251                let worktree_id = worktree.read(cx).id();
 252                this.license_detection_watchers
 253                    .entry(worktree_id)
 254                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 255            }
 256        });
 257
 258        this
 259    }
 260
 261    pub fn clear_history(&mut self) {
 262        for zeta_project in self.projects.values_mut() {
 263            zeta_project.events.clear();
 264        }
 265    }
 266
 267    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 268        self.user_store.read(cx).edit_prediction_usage()
 269    }
 270
 271    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 272        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 273        let data_collection_choice = Self::load_data_collection_choice();
 274        Self {
 275            projects: HashMap::default(),
 276            client,
 277            shown_completions: VecDeque::new(),
 278            rated_completions: HashSet::default(),
 279            data_collection_choice,
 280            llm_token: LlmApiToken::default(),
 281            _llm_token_subscription: cx.subscribe(
 282                &refresh_llm_token_listener,
 283                |this, _listener, _event, cx| {
 284                    let client = this.client.clone();
 285                    let llm_token = this.llm_token.clone();
 286                    cx.spawn(async move |_this, _cx| {
 287                        llm_token.refresh(&client).await?;
 288                        anyhow::Ok(())
 289                    })
 290                    .detach_and_log_err(cx);
 291                },
 292            ),
 293            update_required: false,
 294            license_detection_watchers: HashMap::default(),
 295            user_store,
 296        }
 297    }
 298
 299    fn get_or_init_zeta_project(
 300        &mut self,
 301        project: &Entity<Project>,
 302        cx: &mut Context<Self>,
 303    ) -> &mut ZetaProject {
 304        let project_id = project.entity_id();
 305        match self.projects.entry(project_id) {
 306            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 307            hash_map::Entry::Vacant(entry) => {
 308                cx.observe_release(project, move |this, _, _cx| {
 309                    this.projects.remove(&project_id);
 310                })
 311                .detach();
 312                entry.insert(ZetaProject {
 313                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 314                    registered_buffers: HashMap::default(),
 315                })
 316            }
 317        }
 318    }
 319
 320    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 321        let events = &mut zeta_project.events;
 322
 323        if let Some(Event::BufferChange {
 324            new_snapshot: last_new_snapshot,
 325            timestamp: last_timestamp,
 326            ..
 327        }) = events.back_mut()
 328        {
 329            // Coalesce edits for the same buffer when they happen one after the other.
 330            let Event::BufferChange {
 331                old_snapshot,
 332                new_snapshot,
 333                timestamp,
 334            } = &event;
 335
 336            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 337                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 338                && old_snapshot.version == last_new_snapshot.version
 339            {
 340                *last_new_snapshot = new_snapshot.clone();
 341                *last_timestamp = *timestamp;
 342                return;
 343            }
 344        }
 345
 346        if events.len() >= MAX_EVENT_COUNT {
 347            // These are halved instead of popping to improve prompt caching.
 348            events.drain(..MAX_EVENT_COUNT / 2);
 349        }
 350
 351        events.push_back(event);
 352    }
 353
 354    pub fn register_buffer(
 355        &mut self,
 356        buffer: &Entity<Buffer>,
 357        project: &Entity<Project>,
 358        cx: &mut Context<Self>,
 359    ) {
 360        let zeta_project = self.get_or_init_zeta_project(project, cx);
 361        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 362    }
 363
 364    fn register_buffer_impl<'a>(
 365        zeta_project: &'a mut ZetaProject,
 366        buffer: &Entity<Buffer>,
 367        project: &Entity<Project>,
 368        cx: &mut Context<Self>,
 369    ) -> &'a mut RegisteredBuffer {
 370        let buffer_id = buffer.entity_id();
 371        match zeta_project.registered_buffers.entry(buffer_id) {
 372            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 373            hash_map::Entry::Vacant(entry) => {
 374                let snapshot = buffer.read(cx).snapshot();
 375                let project_entity_id = project.entity_id();
 376                entry.insert(RegisteredBuffer {
 377                    snapshot,
 378                    _subscriptions: [
 379                        cx.subscribe(buffer, {
 380                            let project = project.downgrade();
 381                            move |this, buffer, event, cx| {
 382                                if let language::BufferEvent::Edited = event
 383                                    && let Some(project) = project.upgrade()
 384                                {
 385                                    this.report_changes_for_buffer(&buffer, &project, cx);
 386                                }
 387                            }
 388                        }),
 389                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 390                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 391                            else {
 392                                return;
 393                            };
 394                            zeta_project.registered_buffers.remove(&buffer_id);
 395                        }),
 396                    ],
 397                })
 398            }
 399        }
 400    }
 401
 402    fn request_completion_impl<F, R>(
 403        &mut self,
 404        project: &Entity<Project>,
 405        buffer: &Entity<Buffer>,
 406        cursor: language::Anchor,
 407        cx: &mut Context<Self>,
 408        perform_predict_edits: F,
 409    ) -> Task<Result<Option<EditPrediction>>>
 410    where
 411        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
 412        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
 413            + Send
 414            + 'static,
 415    {
 416        let buffer = buffer.clone();
 417        let buffer_snapshotted_at = Instant::now();
 418        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
 419        let zeta = cx.entity();
 420        let client = self.client.clone();
 421        let llm_token = self.llm_token.clone();
 422        let app_version = AppVersion::global(cx);
 423
 424        let zeta_project = self.get_or_init_zeta_project(project, cx);
 425        let mut events = Vec::with_capacity(zeta_project.events.len());
 426        events.extend(zeta_project.events.iter().cloned());
 427        let events = Arc::new(events);
 428
 429        let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 430            let can_collect_file = self.can_collect_file(file, cx);
 431            let git_info = if can_collect_file {
 432                git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 433            } else {
 434                None
 435            };
 436            (git_info, can_collect_file)
 437        } else {
 438            (None, false)
 439        };
 440
 441        let full_path: Arc<Path> = snapshot
 442            .file()
 443            .map(|f| Arc::from(f.full_path(cx).as_path()))
 444            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 445        let full_path_str = full_path.to_string_lossy().to_string();
 446        let cursor_point = cursor.to_point(&snapshot);
 447        let cursor_offset = cursor_point.to_offset(&snapshot);
 448        let prompt_for_events = {
 449            let events = events.clone();
 450            move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 451        };
 452        let gather_task = gather_context(
 453            full_path_str,
 454            &snapshot,
 455            cursor_point,
 456            prompt_for_events,
 457            cx,
 458        );
 459
 460        cx.spawn(async move |this, cx| {
 461            let GatherContextOutput {
 462                mut body,
 463                editable_range,
 464                included_events_count,
 465            } = gather_task.await?;
 466            let done_gathering_context_at = Instant::now();
 467
 468            let included_events = &events[events.len() - included_events_count..events.len()];
 469            body.can_collect_data = can_collect_file
 470                && this
 471                    .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
 472                    .unwrap_or(false);
 473            if body.can_collect_data {
 474                body.git_info = git_info;
 475            }
 476
 477            log::debug!(
 478                "Events:\n{}\nExcerpt:\n{:?}",
 479                body.input_events,
 480                body.input_excerpt
 481            );
 482
 483            let input_outline = body.outline.clone().unwrap_or_default();
 484            let input_events = body.input_events.clone();
 485            let input_excerpt = body.input_excerpt.clone();
 486
 487            let response = perform_predict_edits(PerformPredictEditsParams {
 488                client,
 489                llm_token,
 490                app_version,
 491                body,
 492            })
 493            .await;
 494            let (response, usage) = match response {
 495                Ok(response) => response,
 496                Err(err) => {
 497                    if err.is::<ZedUpdateRequiredError>() {
 498                        cx.update(|cx| {
 499                            zeta.update(cx, |zeta, _cx| {
 500                                zeta.update_required = true;
 501                            });
 502
 503                            let error_message: SharedString = err.to_string().into();
 504                            show_app_notification(
 505                                NotificationId::unique::<ZedUpdateRequiredError>(),
 506                                cx,
 507                                move |cx| {
 508                                    cx.new(|cx| {
 509                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 510                                            .with_link_button(
 511                                                "Update Zed",
 512                                                "https://zed.dev/releases",
 513                                            )
 514                                    })
 515                                },
 516                            );
 517                        })
 518                        .ok();
 519                    }
 520
 521                    return Err(err);
 522                }
 523            };
 524
 525            let received_response_at = Instant::now();
 526            log::debug!("completion response: {}", &response.output_excerpt);
 527
 528            if let Some(usage) = usage {
 529                this.update(cx, |this, cx| {
 530                    this.user_store.update(cx, |user_store, cx| {
 531                        user_store.update_edit_prediction_usage(usage, cx);
 532                    });
 533                })
 534                .ok();
 535            }
 536
 537            let edit_prediction = Self::process_completion_response(
 538                response,
 539                buffer,
 540                &snapshot,
 541                editable_range,
 542                cursor_offset,
 543                full_path,
 544                input_outline,
 545                input_events,
 546                input_excerpt,
 547                buffer_snapshotted_at,
 548                cx,
 549            )
 550            .await;
 551
 552            let finished_at = Instant::now();
 553
 554            // record latency for ~1% of requests
 555            if rand::random::<u8>() <= 2 {
 556                telemetry::event!(
 557                    "Edit Prediction Request",
 558                    context_latency = done_gathering_context_at
 559                        .duration_since(buffer_snapshotted_at)
 560                        .as_millis(),
 561                    request_latency = received_response_at
 562                        .duration_since(done_gathering_context_at)
 563                        .as_millis(),
 564                    process_latency = finished_at.duration_since(received_response_at).as_millis()
 565                );
 566            }
 567
 568            edit_prediction
 569        })
 570    }
 571
 572    #[cfg(any(test, feature = "test-support"))]
 573    pub fn fake_completion(
 574        &mut self,
 575        project: &Entity<Project>,
 576        buffer: &Entity<Buffer>,
 577        position: language::Anchor,
 578        response: PredictEditsResponse,
 579        cx: &mut Context<Self>,
 580    ) -> Task<Result<Option<EditPrediction>>> {
 581        self.request_completion_impl(project, buffer, position, cx, |_params| {
 582            std::future::ready(Ok((response, None)))
 583        })
 584    }
 585
 586    pub fn request_completion(
 587        &mut self,
 588        project: &Entity<Project>,
 589        buffer: &Entity<Buffer>,
 590        position: language::Anchor,
 591        cx: &mut Context<Self>,
 592    ) -> Task<Result<Option<EditPrediction>>> {
 593        self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
 594    }
 595
 596    pub fn perform_predict_edits(
 597        params: PerformPredictEditsParams,
 598    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 599        async move {
 600            let PerformPredictEditsParams {
 601                client,
 602                llm_token,
 603                app_version,
 604                body,
 605                ..
 606            } = params;
 607
 608            let http_client = client.http_client();
 609            let mut token = llm_token.acquire(&client).await?;
 610            let mut did_retry = false;
 611
 612            loop {
 613                let request_builder = http_client::Request::builder().method(Method::POST);
 614                let request_builder =
 615                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 616                        request_builder.uri(predict_edits_url)
 617                    } else {
 618                        request_builder.uri(
 619                            http_client
 620                                .build_zed_llm_url("/predict_edits/v2", &[])?
 621                                .as_ref(),
 622                        )
 623                    };
 624                let request = request_builder
 625                    .header("Content-Type", "application/json")
 626                    .header("Authorization", format!("Bearer {}", token))
 627                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 628                    .body(serde_json::to_string(&body)?.into())?;
 629
 630                let mut response = http_client.send(request).await?;
 631
 632                if let Some(minimum_required_version) = response
 633                    .headers()
 634                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 635                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 636                {
 637                    anyhow::ensure!(
 638                        app_version >= minimum_required_version,
 639                        ZedUpdateRequiredError {
 640                            minimum_version: minimum_required_version
 641                        }
 642                    );
 643                }
 644
 645                if response.status().is_success() {
 646                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 647
 648                    let mut body = String::new();
 649                    response.body_mut().read_to_string(&mut body).await?;
 650                    return Ok((serde_json::from_str(&body)?, usage));
 651                } else if !did_retry
 652                    && response
 653                        .headers()
 654                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 655                        .is_some()
 656                {
 657                    did_retry = true;
 658                    token = llm_token.refresh(&client).await?;
 659                } else {
 660                    let mut body = String::new();
 661                    response.body_mut().read_to_string(&mut body).await?;
 662                    anyhow::bail!(
 663                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 664                        response.status(),
 665                        body
 666                    );
 667                }
 668            }
 669        }
 670    }
 671
 672    fn accept_edit_prediction(
 673        &mut self,
 674        request_id: EditPredictionId,
 675        cx: &mut Context<Self>,
 676    ) -> Task<Result<()>> {
 677        let client = self.client.clone();
 678        let llm_token = self.llm_token.clone();
 679        let app_version = AppVersion::global(cx);
 680        cx.spawn(async move |this, cx| {
 681            let http_client = client.http_client();
 682            let mut response = llm_token_retry(&llm_token, &client, |token| {
 683                let request_builder = http_client::Request::builder().method(Method::POST);
 684                let request_builder =
 685                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 686                        request_builder.uri(accept_prediction_url)
 687                    } else {
 688                        request_builder.uri(
 689                            http_client
 690                                .build_zed_llm_url("/predict_edits/accept", &[])?
 691                                .as_ref(),
 692                        )
 693                    };
 694                Ok(request_builder
 695                    .header("Content-Type", "application/json")
 696                    .header("Authorization", format!("Bearer {}", token))
 697                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 698                    .body(
 699                        serde_json::to_string(&AcceptEditPredictionBody {
 700                            request_id: request_id.0,
 701                        })?
 702                        .into(),
 703                    )?)
 704            })
 705            .await?;
 706
 707            if let Some(minimum_required_version) = response
 708                .headers()
 709                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 710                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 711                && app_version < minimum_required_version
 712            {
 713                return Err(anyhow!(ZedUpdateRequiredError {
 714                    minimum_version: minimum_required_version
 715                }));
 716            }
 717
 718            if response.status().is_success() {
 719                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 720                    this.update(cx, |this, cx| {
 721                        this.user_store.update(cx, |user_store, cx| {
 722                            user_store.update_edit_prediction_usage(usage, cx);
 723                        });
 724                    })?;
 725                }
 726
 727                Ok(())
 728            } else {
 729                let mut body = String::new();
 730                response.body_mut().read_to_string(&mut body).await?;
 731                Err(anyhow!(
 732                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 733                    response.status(),
 734                    body
 735                ))
 736            }
 737        })
 738    }
 739
 740    fn process_completion_response(
 741        prediction_response: PredictEditsResponse,
 742        buffer: Entity<Buffer>,
 743        snapshot: &BufferSnapshot,
 744        editable_range: Range<usize>,
 745        cursor_offset: usize,
 746        path: Arc<Path>,
 747        input_outline: String,
 748        input_events: String,
 749        input_excerpt: String,
 750        buffer_snapshotted_at: Instant,
 751        cx: &AsyncApp,
 752    ) -> Task<Result<Option<EditPrediction>>> {
 753        let snapshot = snapshot.clone();
 754        let request_id = prediction_response.request_id;
 755        let output_excerpt = prediction_response.output_excerpt;
 756        cx.spawn(async move |cx| {
 757            let output_excerpt: Arc<str> = output_excerpt.into();
 758
 759            let edits: Arc<[(Range<Anchor>, String)]> = cx
 760                .background_spawn({
 761                    let output_excerpt = output_excerpt.clone();
 762                    let editable_range = editable_range.clone();
 763                    let snapshot = snapshot.clone();
 764                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
 765                })
 766                .await?
 767                .into();
 768
 769            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
 770                let edits = edits.clone();
 771                |buffer, cx| {
 772                    let new_snapshot = buffer.snapshot();
 773                    let edits: Arc<[(Range<Anchor>, String)]> =
 774                        interpolate(&snapshot, &new_snapshot, edits)?.into();
 775                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 776                }
 777            })?
 778            else {
 779                return anyhow::Ok(None);
 780            };
 781
 782            let edit_preview = edit_preview.await;
 783
 784            Ok(Some(EditPrediction {
 785                id: EditPredictionId(request_id),
 786                path,
 787                excerpt_range: editable_range,
 788                cursor_offset,
 789                edits,
 790                edit_preview,
 791                snapshot,
 792                input_outline: input_outline.into(),
 793                input_events: input_events.into(),
 794                input_excerpt: input_excerpt.into(),
 795                output_excerpt,
 796                buffer_snapshotted_at,
 797                response_received_at: Instant::now(),
 798            }))
 799        })
 800    }
 801
 802    fn parse_edits(
 803        output_excerpt: Arc<str>,
 804        editable_range: Range<usize>,
 805        snapshot: &BufferSnapshot,
 806    ) -> Result<Vec<(Range<Anchor>, String)>> {
 807        let content = output_excerpt.replace(CURSOR_MARKER, "");
 808
 809        let start_markers = content
 810            .match_indices(EDITABLE_REGION_START_MARKER)
 811            .collect::<Vec<_>>();
 812        anyhow::ensure!(
 813            start_markers.len() == 1,
 814            "expected exactly one start marker, found {}",
 815            start_markers.len()
 816        );
 817
 818        let end_markers = content
 819            .match_indices(EDITABLE_REGION_END_MARKER)
 820            .collect::<Vec<_>>();
 821        anyhow::ensure!(
 822            end_markers.len() == 1,
 823            "expected exactly one end marker, found {}",
 824            end_markers.len()
 825        );
 826
 827        let sof_markers = content
 828            .match_indices(START_OF_FILE_MARKER)
 829            .collect::<Vec<_>>();
 830        anyhow::ensure!(
 831            sof_markers.len() <= 1,
 832            "expected at most one start-of-file marker, found {}",
 833            sof_markers.len()
 834        );
 835
 836        let codefence_start = start_markers[0].0;
 837        let content = &content[codefence_start..];
 838
 839        let newline_ix = content.find('\n').context("could not find newline")?;
 840        let content = &content[newline_ix + 1..];
 841
 842        let codefence_end = content
 843            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 844            .context("could not find end marker")?;
 845        let new_text = &content[..codefence_end];
 846
 847        let old_text = snapshot
 848            .text_for_range(editable_range.clone())
 849            .collect::<String>();
 850
 851        Ok(Self::compute_edits(
 852            old_text,
 853            new_text,
 854            editable_range.start,
 855            snapshot,
 856        ))
 857    }
 858
 859    pub fn compute_edits(
 860        old_text: String,
 861        new_text: &str,
 862        offset: usize,
 863        snapshot: &BufferSnapshot,
 864    ) -> Vec<(Range<Anchor>, String)> {
 865        text_diff(&old_text, new_text)
 866            .into_iter()
 867            .map(|(mut old_range, new_text)| {
 868                old_range.start += offset;
 869                old_range.end += offset;
 870
 871                let prefix_len = common_prefix(
 872                    snapshot.chars_for_range(old_range.clone()),
 873                    new_text.chars(),
 874                );
 875                old_range.start += prefix_len;
 876
 877                let suffix_len = common_prefix(
 878                    snapshot.reversed_chars_for_range(old_range.clone()),
 879                    new_text[prefix_len..].chars().rev(),
 880                );
 881                old_range.end = old_range.end.saturating_sub(suffix_len);
 882
 883                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 884                let range = if old_range.is_empty() {
 885                    let anchor = snapshot.anchor_after(old_range.start);
 886                    anchor..anchor
 887                } else {
 888                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 889                };
 890                (range, new_text)
 891            })
 892            .collect()
 893    }
 894
 895    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
 896        self.rated_completions.contains(&completion_id)
 897    }
 898
 899    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
 900        self.shown_completions.push_front(completion.clone());
 901        if self.shown_completions.len() > 50 {
 902            let completion = self.shown_completions.pop_back().unwrap();
 903            self.rated_completions.remove(&completion.id);
 904        }
 905        cx.notify();
 906    }
 907
    /// Records the user's `rating` and free-form `feedback` for `completion`.
    ///
    /// Steps:
    /// - mark the completion as rated;
    /// - emit an "Edit Prediction Rated" telemetry event carrying the completion's input
    ///   events, excerpt, outline and the model's output excerpt;
    /// - request an immediate telemetry flush. The flush task is detached, not awaited.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush right away so the rating is not held in the telemetry queue.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
 928
 929    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
 930        self.shown_completions.iter()
 931    }
 932
 933    pub fn shown_completions_len(&self) -> usize {
 934        self.shown_completions.len()
 935    }
 936
 937    fn report_changes_for_buffer(
 938        &mut self,
 939        buffer: &Entity<Buffer>,
 940        project: &Entity<Project>,
 941        cx: &mut Context<Self>,
 942    ) -> BufferSnapshot {
 943        let zeta_project = self.get_or_init_zeta_project(project, cx);
 944        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 945
 946        let new_snapshot = buffer.read(cx).snapshot();
 947        if new_snapshot.version != registered_buffer.snapshot.version {
 948            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 949            Self::push_event(
 950                zeta_project,
 951                Event::BufferChange {
 952                    old_snapshot,
 953                    new_snapshot: new_snapshot.clone(),
 954                    timestamp: Instant::now(),
 955                },
 956            );
 957        }
 958
 959        new_snapshot
 960    }
 961
 962    fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
 963        self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
 964    }
 965
 966    fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
 967        if !self.data_collection_choice.is_enabled() {
 968            return false;
 969        }
 970        let mut last_checked_file = None;
 971        for event in events {
 972            match event {
 973                Event::BufferChange {
 974                    old_snapshot,
 975                    new_snapshot,
 976                    ..
 977                } => {
 978                    if let Some(old_file) = old_snapshot.file()
 979                        && let Some(new_file) = new_snapshot.file()
 980                    {
 981                        if let Some(last_checked_file) = last_checked_file
 982                            && Arc::ptr_eq(last_checked_file, old_file)
 983                            && Arc::ptr_eq(last_checked_file, new_file)
 984                        {
 985                            continue;
 986                        }
 987                        if !self.can_collect_file(old_file, cx) {
 988                            return false;
 989                        }
 990                        if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
 991                        {
 992                            return false;
 993                        }
 994                        last_checked_file = Some(new_file);
 995                    } else {
 996                        return false;
 997                    }
 998                }
 999            }
1000        }
1001        true
1002    }
1003
1004    fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1005        if !file.is_local() || file.is_private() {
1006            return false;
1007        }
1008        self.license_detection_watchers
1009            .get(&file.worktree_id(cx))
1010            .is_some_and(|watcher| watcher.is_project_open_source())
1011    }
1012
1013    fn load_data_collection_choice() -> DataCollectionChoice {
1014        let choice = KEY_VALUE_STORE
1015            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1016            .log_err()
1017            .flatten();
1018
1019        match choice.as_deref() {
1020            Some("true") => DataCollectionChoice::Enabled,
1021            Some("false") => DataCollectionChoice::Disabled,
1022            Some(_) => {
1023                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1024                DataCollectionChoice::NotAnswered
1025            }
1026            None => DataCollectionChoice::NotAnswered,
1027        }
1028    }
1029
1030    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1031        self.data_collection_choice = self.data_collection_choice.toggle();
1032        let new_choice = self.data_collection_choice;
1033        db::write_and_log(cx, move || {
1034            KEY_VALUE_STORE.write_kvp(
1035                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1036                new_choice.is_enabled().to_string(),
1037            )
1038        });
1039    }
1040}
1041
/// Inputs bundled together for a predict-edits request.
///
/// NOTE(review): the consumer of this struct is outside this chunk; the field notes below
/// are inferred from names and types — confirm against the caller.
pub struct PerformPredictEditsParams {
    /// Client whose connection/HTTP client is presumably used to send the request.
    pub client: Arc<Client>,
    /// LLM API token, presumably used to authorize the request.
    pub llm_token: LlmApiToken,
    /// Version of the running app, presumably reported to the server.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1048
/// Error signalling that this Zed build is too old to keep using edit predictions.
///
/// Its `Display` text (via `thiserror`) tells the user which version to update to.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version that may continue using edit predictions.
    // NOTE(review): presumably parsed from the server's minimum-required-version header;
    // confirm where this error is constructed.
    minimum_version: SemanticVersion,
}
1056
/// Returns the UTF-8 byte length of the longest common prefix of two char sequences.
///
/// Comparison stops at the first mismatch or when either sequence runs out. Lengths are
/// counted in bytes so the result can be used directly to adjust byte offsets.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1063
1064fn git_info_for_file(
1065    project: &Entity<Project>,
1066    project_path: &ProjectPath,
1067    cx: &App,
1068) -> Option<PredictEditsGitInfo> {
1069    let git_store = project.read(cx).git_store().read(cx);
1070    if let Some((repository, _repo_path)) =
1071        git_store.repository_and_path_for_project_path(project_path, cx)
1072    {
1073        let repository = repository.read(cx);
1074        let head_sha = repository
1075            .head_commit
1076            .as_ref()
1077            .map(|head_commit| head_commit.sha.to_string());
1078        let remote_origin_url = repository.remote_origin_url.clone();
1079        let remote_upstream_url = repository.remote_upstream_url.clone();
1080        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1081            return None;
1082        }
1083        Some(PredictEditsGitInfo {
1084            head_sha,
1085            remote_origin_url,
1086            remote_upstream_url,
1087        })
1088    } else {
1089        None
1090    }
1091}
1092
/// Result of [`gather_context`]: the request body plus bookkeeping about what went into it.
pub struct GatherContextOutput {
    /// Request body. `gather_context` fills only `input_events` and `input_excerpt`; every
    /// optional field is `None` and `can_collect_data` is `false`.
    pub body: PredictEditsBody,
    /// Offset range in the buffer snapshot covered by the excerpt's editable region.
    pub editable_range: Range<usize>,
    /// Number of events rendered into `body.input_events`.
    pub included_events_count: usize,
}
1098
1099pub fn gather_context(
1100    full_path_str: String,
1101    snapshot: &BufferSnapshot,
1102    cursor_point: language::Point,
1103    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1104    cx: &App,
1105) -> Task<Result<GatherContextOutput>> {
1106    cx.background_spawn({
1107        let snapshot = snapshot.clone();
1108        async move {
1109            let input_excerpt = excerpt_for_cursor_position(
1110                cursor_point,
1111                &full_path_str,
1112                &snapshot,
1113                MAX_REWRITE_TOKENS,
1114                MAX_CONTEXT_TOKENS,
1115            );
1116            let (input_events, included_events_count) = prompt_for_events();
1117            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1118
1119            let body = PredictEditsBody {
1120                input_events,
1121                input_excerpt: input_excerpt.prompt,
1122                can_collect_data: false,
1123                diagnostic_groups: None,
1124                git_info: None,
1125                outline: None,
1126                speculated_output: None,
1127            };
1128
1129            Ok(GatherContextOutput {
1130                body,
1131                editable_range,
1132                included_events_count,
1133            })
1134        }
1135    })
1136}
1137
1138fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1139    let mut result = String::new();
1140    for (ix, event) in events.iter().rev().enumerate() {
1141        let event_string = event.to_prompt();
1142        let event_tokens = guess_token_count(event_string.len());
1143        if event_tokens > remaining_tokens {
1144            return (result, ix);
1145        }
1146
1147        if !result.is_empty() {
1148            result.insert_str(0, "\n\n");
1149        }
1150        result.insert_str(0, &event_string);
1151        remaining_tokens -= event_tokens;
1152    }
1153    return (result, events.len());
1154}
1155
/// Per-buffer state kept while a buffer is registered for change tracking.
struct RegisteredBuffer {
    /// Last snapshot seen; `report_changes_for_buffer` compares its version with the
    /// buffer's current one and replaces it when they differ.
    snapshot: BufferSnapshot,
    /// Subscriptions held for as long as the buffer stays registered.
    // NOTE(review): created by `register_buffer_impl` (not in this chunk); presumably
    // buffer-change / release observers — confirm.
    _subscriptions: [gpui::Subscription; 2],
}
1160
/// A recorded user edit, rendered into prediction prompts via `Event::to_prompt`.
#[derive(Clone)]
pub enum Event {
    /// A buffer's content (and possibly its file path) changed between two snapshots.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1169
1170impl Event {
1171    fn to_prompt(&self) -> String {
1172        match self {
1173            Event::BufferChange {
1174                old_snapshot,
1175                new_snapshot,
1176                ..
1177            } => {
1178                let mut prompt = String::new();
1179
1180                let old_path = old_snapshot
1181                    .file()
1182                    .map(|f| f.path().as_ref())
1183                    .unwrap_or(Path::new("untitled"));
1184                let new_path = new_snapshot
1185                    .file()
1186                    .map(|f| f.path().as_ref())
1187                    .unwrap_or(Path::new("untitled"));
1188                if old_path != new_path {
1189                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1190                }
1191
1192                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1193                if !diff.is_empty() {
1194                    write!(
1195                        prompt,
1196                        "User edited {:?}:\n```diff\n{}\n```",
1197                        new_path, diff
1198                    )
1199                    .unwrap();
1200                }
1201
1202                prompt
1203            }
1204        }
1205    }
1206}
1207
/// The prediction currently offered by the provider, tagged with the buffer it targets.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
}
1213
1214impl CurrentEditPrediction {
1215    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1216        if self.buffer_id != old_completion.buffer_id {
1217            return true;
1218        }
1219
1220        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1221            return true;
1222        };
1223        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1224            return false;
1225        };
1226
1227        if old_edits.len() == 1 && new_edits.len() == 1 {
1228            let (old_range, old_text) = &old_edits[0];
1229            let (new_range, new_text) = &new_edits[0];
1230            new_range == old_range && new_text.starts_with(old_text)
1231        } else {
1232            true
1233        }
1234    }
1235}
1236
/// An in-flight prediction request started by `refresh`.
struct PendingCompletion {
    /// Sequence number taken from `next_pending_completion_id`; used to tell which
    /// queued request just finished.
    id: usize,
    /// The spawned request task, owned by the queue entry so that removing the entry
    /// drops the task.
    _task: Task<()>,
}
1241
/// The user's answer to the data-collection opt-in prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True once the user has made any explicit choice.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered prompt toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` maps to `Enabled`, `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1282
1283async fn llm_token_retry(
1284    llm_token: &LlmApiToken,
1285    client: &Arc<Client>,
1286    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1287) -> Result<Response<AsyncBody>> {
1288    let mut did_retry = false;
1289    let http_client = client.http_client();
1290    let mut token = llm_token.acquire(client).await?;
1291    loop {
1292        let request = build_request(token.clone())?;
1293        let response = http_client.send(request).await?;
1294
1295        if !did_retry
1296            && !response.status().is_success()
1297            && response
1298                .headers()
1299                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1300                .is_some()
1301        {
1302            did_retry = true;
1303            token = llm_token.refresh(client).await?;
1304            continue;
1305        }
1306
1307        return Ok(response);
1308    }
1309}
1310
/// Adapts [`Zeta`] to the editor's `EditPredictionProvider` interface.
pub struct ZetaEditPredictionProvider {
    /// Shared prediction engine.
    zeta: Entity<Zeta>,
    /// Buffer consulted by `data_collection_state`, when the editor has a single one.
    singleton_buffer: Option<Entity<Buffer>>,
    /// In-flight requests, oldest first; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id assigned to the next request started by `refresh`.
    next_pending_completion_id: usize,
    /// Prediction currently offered to the editor, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was issued; used to throttle `refresh`.
    last_request_timestamp: Instant,
}
1319
1320impl ZetaEditPredictionProvider {
1321    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1322
1323    pub fn new(zeta: Entity<Zeta>, singleton_buffer: Option<Entity<Buffer>>) -> Self {
1324        Self {
1325            zeta,
1326            singleton_buffer,
1327            pending_completions: ArrayVec::new(),
1328            next_pending_completion_id: 0,
1329            current_completion: None,
1330            last_request_timestamp: Instant::now(),
1331        }
1332    }
1333}
1334
1335impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1336    fn name() -> &'static str {
1337        "zed-predict"
1338    }
1339
1340    fn display_name() -> &'static str {
1341        "Zed's Edit Predictions"
1342    }
1343
1344    fn show_completions_in_menu() -> bool {
1345        true
1346    }
1347
1348    fn show_tab_accept_marker() -> bool {
1349        true
1350    }
1351
1352    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1353        if let Some(buffer) = &self.singleton_buffer
1354            && let Some(file) = buffer.read(cx).file()
1355        {
1356            let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
1357            if self.zeta.read(cx).data_collection_choice.is_enabled() {
1358                DataCollectionState::Enabled {
1359                    is_project_open_source,
1360                }
1361            } else {
1362                DataCollectionState::Disabled {
1363                    is_project_open_source,
1364                }
1365            }
1366        } else {
1367            return DataCollectionState::Disabled {
1368                is_project_open_source: false,
1369            };
1370        }
1371    }
1372
1373    fn toggle_data_collection(&mut self, cx: &mut App) {
1374        self.zeta
1375            .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
1376    }
1377
1378    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1379        self.zeta.read(cx).usage(cx)
1380    }
1381
1382    fn is_enabled(
1383        &self,
1384        _buffer: &Entity<Buffer>,
1385        _cursor_position: language::Anchor,
1386        _cx: &App,
1387    ) -> bool {
1388        true
1389    }
1390    fn is_refreshing(&self) -> bool {
1391        !self.pending_completions.is_empty()
1392    }
1393
1394    fn refresh(
1395        &mut self,
1396        project: Option<Entity<Project>>,
1397        buffer: Entity<Buffer>,
1398        position: language::Anchor,
1399        _debounce: bool,
1400        cx: &mut Context<Self>,
1401    ) {
1402        if self.zeta.read(cx).update_required {
1403            return;
1404        }
1405        let Some(project) = project else {
1406            return;
1407        };
1408
1409        if self
1410            .zeta
1411            .read(cx)
1412            .user_store
1413            .read_with(cx, |user_store, _cx| {
1414                user_store.account_too_young() || user_store.has_overdue_invoices()
1415            })
1416        {
1417            return;
1418        }
1419
1420        if let Some(current_completion) = self.current_completion.as_ref() {
1421            let snapshot = buffer.read(cx).snapshot();
1422            if current_completion
1423                .completion
1424                .interpolate(&snapshot)
1425                .is_some()
1426            {
1427                return;
1428            }
1429        }
1430
1431        let pending_completion_id = self.next_pending_completion_id;
1432        self.next_pending_completion_id += 1;
1433        let last_request_timestamp = self.last_request_timestamp;
1434
1435        let task = cx.spawn(async move |this, cx| {
1436            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1437                .checked_duration_since(Instant::now())
1438            {
1439                cx.background_executor().timer(timeout).await;
1440            }
1441
1442            let completion_request = this.update(cx, |this, cx| {
1443                this.last_request_timestamp = Instant::now();
1444                this.zeta.update(cx, |zeta, cx| {
1445                    zeta.request_completion(&project, &buffer, position, cx)
1446                })
1447            });
1448
1449            let completion = match completion_request {
1450                Ok(completion_request) => {
1451                    let completion_request = completion_request.await;
1452                    completion_request.map(|c| {
1453                        c.map(|completion| CurrentEditPrediction {
1454                            buffer_id: buffer.entity_id(),
1455                            completion,
1456                        })
1457                    })
1458                }
1459                Err(error) => Err(error),
1460            };
1461            let Some(new_completion) = completion
1462                .context("edit prediction failed")
1463                .log_err()
1464                .flatten()
1465            else {
1466                this.update(cx, |this, cx| {
1467                    if this.pending_completions[0].id == pending_completion_id {
1468                        this.pending_completions.remove(0);
1469                    } else {
1470                        this.pending_completions.clear();
1471                    }
1472
1473                    cx.notify();
1474                })
1475                .ok();
1476                return;
1477            };
1478
1479            this.update(cx, |this, cx| {
1480                if this.pending_completions[0].id == pending_completion_id {
1481                    this.pending_completions.remove(0);
1482                } else {
1483                    this.pending_completions.clear();
1484                }
1485
1486                if let Some(old_completion) = this.current_completion.as_ref() {
1487                    let snapshot = buffer.read(cx).snapshot();
1488                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1489                        this.zeta.update(cx, |zeta, cx| {
1490                            zeta.completion_shown(&new_completion.completion, cx);
1491                        });
1492                        this.current_completion = Some(new_completion);
1493                    }
1494                } else {
1495                    this.zeta.update(cx, |zeta, cx| {
1496                        zeta.completion_shown(&new_completion.completion, cx);
1497                    });
1498                    this.current_completion = Some(new_completion);
1499                }
1500
1501                cx.notify();
1502            })
1503            .ok();
1504        });
1505
1506        // We always maintain at most two pending completions. When we already
1507        // have two, we replace the newest one.
1508        if self.pending_completions.len() <= 1 {
1509            self.pending_completions.push(PendingCompletion {
1510                id: pending_completion_id,
1511                _task: task,
1512            });
1513        } else if self.pending_completions.len() == 2 {
1514            self.pending_completions.pop();
1515            self.pending_completions.push(PendingCompletion {
1516                id: pending_completion_id,
1517                _task: task,
1518            });
1519        }
1520    }
1521
1522    fn cycle(
1523        &mut self,
1524        _buffer: Entity<Buffer>,
1525        _cursor_position: language::Anchor,
1526        _direction: edit_prediction::Direction,
1527        _cx: &mut Context<Self>,
1528    ) {
1529        // Right now we don't support cycling.
1530    }
1531
1532    fn accept(&mut self, cx: &mut Context<Self>) {
1533        let completion_id = self
1534            .current_completion
1535            .as_ref()
1536            .map(|completion| completion.completion.id);
1537        if let Some(completion_id) = completion_id {
1538            self.zeta
1539                .update(cx, |zeta, cx| {
1540                    zeta.accept_edit_prediction(completion_id, cx)
1541                })
1542                .detach();
1543        }
1544        self.pending_completions.clear();
1545    }
1546
1547    fn discard(&mut self, _cx: &mut Context<Self>) {
1548        self.pending_completions.clear();
1549        self.current_completion.take();
1550    }
1551
1552    fn suggest(
1553        &mut self,
1554        buffer: &Entity<Buffer>,
1555        cursor_position: language::Anchor,
1556        cx: &mut Context<Self>,
1557    ) -> Option<edit_prediction::EditPrediction> {
1558        let CurrentEditPrediction {
1559            buffer_id,
1560            completion,
1561            ..
1562        } = self.current_completion.as_mut()?;
1563
1564        // Invalidate previous completion if it was generated for a different buffer.
1565        if *buffer_id != buffer.entity_id() {
1566            self.current_completion.take();
1567            return None;
1568        }
1569
1570        let buffer = buffer.read(cx);
1571        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1572            self.current_completion.take();
1573            return None;
1574        };
1575
1576        let cursor_row = cursor_position.to_point(buffer).row;
1577        let (closest_edit_ix, (closest_edit_range, _)) =
1578            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1579                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1580                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1581                cmp::min(distance_from_start, distance_from_end)
1582            })?;
1583
1584        let mut edit_start_ix = closest_edit_ix;
1585        for (range, _) in edits[..edit_start_ix].iter().rev() {
1586            let distance_from_closest_edit =
1587                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1588            if distance_from_closest_edit <= 1 {
1589                edit_start_ix -= 1;
1590            } else {
1591                break;
1592            }
1593        }
1594
1595        let mut edit_end_ix = closest_edit_ix + 1;
1596        for (range, _) in &edits[edit_end_ix..] {
1597            let distance_from_closest_edit =
1598                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1599            if distance_from_closest_edit <= 1 {
1600                edit_end_ix += 1;
1601            } else {
1602                break;
1603            }
1604        }
1605
1606        Some(edit_prediction::EditPrediction {
1607            id: Some(completion.id.to_string().into()),
1608            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1609            edit_preview: Some(completion.edit_preview.clone()),
1610        })
1611    }
1612}
1613
/// Assumed number of bytes per model token when sizing prompt input. Deliberately low, so
/// token counts are overestimated and input limits are approached conservatively.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Rough token estimate for `bytes` bytes of text, rounded down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1621
1622#[cfg(test)]
1623mod tests {
1624    use client::test::FakeServer;
1625    use clock::FakeSystemClock;
1626    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1627    use gpui::TestAppContext;
1628    use http_client::FakeHttpClient;
1629    use indoc::indoc;
1630    use language::Point;
1631    use parking_lot::Mutex;
1632    use serde_json::json;
1633    use settings::SettingsStore;
1634    use util::path;
1635
1636    use super::*;
1637
1638    const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1639
1640    #[gpui::test]
1641    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1642        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1643        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1644            to_completion_edits(
1645                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1646                &buffer,
1647                cx,
1648            )
1649            .into()
1650        });
1651
1652        let edit_preview = cx
1653            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1654            .await;
1655
1656        let completion = EditPrediction {
1657            edits,
1658            edit_preview,
1659            path: Path::new("").into(),
1660            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1661            id: EditPredictionId(Uuid::new_v4()),
1662            excerpt_range: 0..0,
1663            cursor_offset: 0,
1664            input_outline: "".into(),
1665            input_events: "".into(),
1666            input_excerpt: "".into(),
1667            output_excerpt: "".into(),
1668            buffer_snapshotted_at: Instant::now(),
1669            response_received_at: Instant::now(),
1670        };
1671
1672        cx.update(|cx| {
1673            assert_eq!(
1674                from_completion_edits(
1675                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1676                    &buffer,
1677                    cx
1678                ),
1679                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1680            );
1681
1682            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1683            assert_eq!(
1684                from_completion_edits(
1685                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1686                    &buffer,
1687                    cx
1688                ),
1689                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1690            );
1691
1692            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1693            assert_eq!(
1694                from_completion_edits(
1695                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1696                    &buffer,
1697                    cx
1698                ),
1699                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1700            );
1701
1702            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1703            assert_eq!(
1704                from_completion_edits(
1705                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1706                    &buffer,
1707                    cx
1708                ),
1709                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1710            );
1711
1712            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1713            assert_eq!(
1714                from_completion_edits(
1715                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1716                    &buffer,
1717                    cx
1718                ),
1719                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1720            );
1721
1722            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1723            assert_eq!(
1724                from_completion_edits(
1725                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1726                    &buffer,
1727                    cx
1728                ),
1729                vec![(9..11, "".to_string())]
1730            );
1731
1732            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1733            assert_eq!(
1734                from_completion_edits(
1735                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1736                    &buffer,
1737                    cx
1738                ),
1739                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1740            );
1741
1742            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1743            assert_eq!(
1744                from_completion_edits(
1745                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1746                    &buffer,
1747                    cx
1748                ),
1749                vec![(4..4, "M".to_string())]
1750            );
1751
1752            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1753            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1754        })
1755    }
1756
1757    #[gpui::test]
1758    async fn test_clean_up_diff(cx: &mut TestAppContext) {
1759        init_test(cx);
1760
1761        assert_eq!(
1762            apply_edit_prediction(
1763                indoc! {"
1764                    fn main() {
1765                        let word_1 = \"lorem\";
1766                        let range = word.len()..word.len();
1767                    }
1768                "},
1769                indoc! {"
1770                    <|editable_region_start|>
1771                    fn main() {
1772                        let word_1 = \"lorem\";
1773                        let range = word_1.len()..word_1.len();
1774                    }
1775
1776                    <|editable_region_end|>
1777                "},
1778                cx,
1779            )
1780            .await,
1781            indoc! {"
1782                fn main() {
1783                    let word_1 = \"lorem\";
1784                    let range = word_1.len()..word_1.len();
1785                }
1786            "},
1787        );
1788
1789        assert_eq!(
1790            apply_edit_prediction(
1791                indoc! {"
1792                    fn main() {
1793                        let story = \"the quick\"
1794                    }
1795                "},
1796                indoc! {"
1797                    <|editable_region_start|>
1798                    fn main() {
1799                        let story = \"the quick brown fox jumps over the lazy dog\";
1800                    }
1801
1802                    <|editable_region_end|>
1803                "},
1804                cx,
1805            )
1806            .await,
1807            indoc! {"
1808                fn main() {
1809                    let story = \"the quick brown fox jumps over the lazy dog\";
1810                }
1811            "},
1812        );
1813    }
1814
1815    #[gpui::test]
1816    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1817        init_test(cx);
1818
1819        let buffer_content = "lorem\n";
1820        let completion_response = indoc! {"
1821            ```animals.js
1822            <|start_of_file|>
1823            <|editable_region_start|>
1824            lorem
1825            ipsum
1826            <|editable_region_end|>
1827            ```"};
1828
1829        assert_eq!(
1830            apply_edit_prediction(buffer_content, completion_response, cx).await,
1831            "lorem\nipsum"
1832        );
1833    }
1834
1835    #[gpui::test]
1836    async fn test_can_collect_data(cx: &mut TestAppContext) {
1837        init_test(cx);
1838
1839        let fs = project::FakeFs::new(cx.executor());
1840        fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1841            .await;
1842
1843        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1844        let buffer = project
1845            .update(cx, |project, cx| {
1846                project.open_local_buffer(path!("/project/src/main.rs"), cx)
1847            })
1848            .await
1849            .unwrap();
1850
1851        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1852        zeta.update(cx, |zeta, _cx| {
1853            zeta.data_collection_choice = DataCollectionChoice::Enabled
1854        });
1855
1856        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1857        assert_eq!(
1858            captured_request.lock().clone().unwrap().can_collect_data,
1859            true
1860        );
1861
1862        zeta.update(cx, |zeta, _cx| {
1863            zeta.data_collection_choice = DataCollectionChoice::Disabled
1864        });
1865
1866        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1867        assert_eq!(
1868            captured_request.lock().clone().unwrap().can_collect_data,
1869            false
1870        );
1871    }
1872
1873    #[gpui::test]
1874    async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1875        init_test(cx);
1876
1877        let fs = project::FakeFs::new(cx.executor());
1878        let project = Project::test(fs.clone(), [], cx).await;
1879
1880        let buffer = cx.new(|_cx| {
1881            Buffer::remote(
1882                language::BufferId::new(1).unwrap(),
1883                1,
1884                language::Capability::ReadWrite,
1885                "fn main() {\n    println!(\"Hello\");\n}",
1886            )
1887        });
1888
1889        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1890        zeta.update(cx, |zeta, _cx| {
1891            zeta.data_collection_choice = DataCollectionChoice::Enabled
1892        });
1893
1894        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1895        assert_eq!(
1896            captured_request.lock().clone().unwrap().can_collect_data,
1897            false
1898        );
1899    }
1900
1901    #[gpui::test]
1902    async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1903        init_test(cx);
1904
1905        let fs = project::FakeFs::new(cx.executor());
1906        fs.insert_tree(
1907            path!("/project"),
1908            json!({
1909                "LICENSE": BSD_0_TXT,
1910                ".env": "SECRET_KEY=secret"
1911            }),
1912        )
1913        .await;
1914
1915        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1916        let buffer = project
1917            .update(cx, |project, cx| {
1918                project.open_local_buffer("/project/.env", cx)
1919            })
1920            .await
1921            .unwrap();
1922
1923        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1924        zeta.update(cx, |zeta, _cx| {
1925            zeta.data_collection_choice = DataCollectionChoice::Enabled
1926        });
1927
1928        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1929        assert_eq!(
1930            captured_request.lock().clone().unwrap().can_collect_data,
1931            false
1932        );
1933    }
1934
1935    #[gpui::test]
1936    async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1937        init_test(cx);
1938
1939        let fs = project::FakeFs::new(cx.executor());
1940        let project = Project::test(fs.clone(), [], cx).await;
1941        let buffer = cx.new(|cx| Buffer::local("", cx));
1942
1943        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1944        zeta.update(cx, |zeta, _cx| {
1945            zeta.data_collection_choice = DataCollectionChoice::Enabled
1946        });
1947
1948        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1949        assert_eq!(
1950            captured_request.lock().clone().unwrap().can_collect_data,
1951            false
1952        );
1953    }
1954
1955    #[gpui::test]
1956    async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1957        init_test(cx);
1958
1959        let fs = project::FakeFs::new(cx.executor());
1960        fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1961            .await;
1962
1963        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1964        let buffer = project
1965            .update(cx, |project, cx| {
1966                project.open_local_buffer("/project/main.rs", cx)
1967            })
1968            .await
1969            .unwrap();
1970
1971        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1972        zeta.update(cx, |zeta, _cx| {
1973            zeta.data_collection_choice = DataCollectionChoice::Enabled
1974        });
1975
1976        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1977        assert_eq!(
1978            captured_request.lock().clone().unwrap().can_collect_data,
1979            false
1980        );
1981    }
1982
1983    #[gpui::test]
1984    async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1985        init_test(cx);
1986
1987        let fs = project::FakeFs::new(cx.executor());
1988        fs.insert_tree(
1989            path!("/open_source_worktree"),
1990            json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1991        )
1992        .await;
1993        fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1994            .await;
1995
1996        let project = Project::test(
1997            fs.clone(),
1998            [
1999                path!("/open_source_worktree").as_ref(),
2000                path!("/closed_source_worktree").as_ref(),
2001            ],
2002            cx,
2003        )
2004        .await;
2005        let buffer = project
2006            .update(cx, |project, cx| {
2007                project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
2008            })
2009            .await
2010            .unwrap();
2011
2012        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2013        zeta.update(cx, |zeta, _cx| {
2014            zeta.data_collection_choice = DataCollectionChoice::Enabled
2015        });
2016
2017        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2018        assert_eq!(
2019            captured_request.lock().clone().unwrap().can_collect_data,
2020            true
2021        );
2022
2023        let closed_source_file = project
2024            .update(cx, |project, cx| {
2025                let worktree2 = project
2026                    .worktree_for_root_name("closed_source_worktree", cx)
2027                    .unwrap();
2028                worktree2.update(cx, |worktree2, cx| {
2029                    worktree2.load_file(Path::new("main.rs"), cx)
2030                })
2031            })
2032            .await
2033            .unwrap()
2034            .file;
2035
2036        buffer.update(cx, |buffer, cx| {
2037            buffer.file_updated(closed_source_file, cx);
2038        });
2039
2040        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2041        assert_eq!(
2042            captured_request.lock().clone().unwrap().can_collect_data,
2043            false
2044        );
2045    }
2046
2047    #[gpui::test]
2048    async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
2049        init_test(cx);
2050
2051        let fs = project::FakeFs::new(cx.executor());
2052        fs.insert_tree(
2053            path!("/worktree1"),
2054            json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
2055        )
2056        .await;
2057        fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
2058            .await;
2059
2060        let project = Project::test(
2061            fs.clone(),
2062            [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
2063            cx,
2064        )
2065        .await;
2066        let buffer = project
2067            .update(cx, |project, cx| {
2068                project.open_local_buffer(path!("/worktree1/main.rs"), cx)
2069            })
2070            .await
2071            .unwrap();
2072        let private_buffer = project
2073            .update(cx, |project, cx| {
2074                project.open_local_buffer(path!("/worktree2/file.rs"), cx)
2075            })
2076            .await
2077            .unwrap();
2078
2079        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2080        zeta.update(cx, |zeta, _cx| {
2081            zeta.data_collection_choice = DataCollectionChoice::Enabled
2082        });
2083
2084        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2085        assert_eq!(
2086            captured_request.lock().clone().unwrap().can_collect_data,
2087            true
2088        );
2089
2090        // this has a side effect of registering the buffer to watch for edits
2091        run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
2092        assert_eq!(
2093            captured_request.lock().clone().unwrap().can_collect_data,
2094            false
2095        );
2096
2097        private_buffer.update(cx, |private_buffer, cx| {
2098            private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
2099        });
2100
2101        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2102        assert_eq!(
2103            captured_request.lock().clone().unwrap().can_collect_data,
2104            false
2105        );
2106
2107        // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
2108        // included
2109        buffer.update(cx, |buffer, cx| {
2110            buffer.edit(
2111                [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
2112                None,
2113                cx,
2114            );
2115        });
2116
2117        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2118        assert_eq!(
2119            captured_request.lock().clone().unwrap().can_collect_data,
2120            true
2121        );
2122    }
2123
2124    fn init_test(cx: &mut TestAppContext) {
2125        cx.update(|cx| {
2126            let settings_store = SettingsStore::test(cx);
2127            cx.set_global(settings_store);
2128            language::init(cx);
2129            client::init_settings(cx);
2130            Project::init_settings(cx);
2131        });
2132    }
2133
2134    async fn apply_edit_prediction(
2135        buffer_content: &str,
2136        completion_response: &str,
2137        cx: &mut TestAppContext,
2138    ) -> String {
2139        let fs = project::FakeFs::new(cx.executor());
2140        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2141        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2142        let (zeta, _, response) = make_test_zeta(&project, cx).await;
2143        *response.lock() = completion_response.to_string();
2144        let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
2145        buffer.update(cx, |buffer, cx| {
2146            buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2147        });
2148        buffer.read_with(cx, |buffer, _| buffer.text())
2149    }
2150
2151    async fn run_edit_prediction(
2152        buffer: &Entity<Buffer>,
2153        project: &Entity<Project>,
2154        zeta: &Entity<Zeta>,
2155        cx: &mut TestAppContext,
2156    ) -> EditPrediction {
2157        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2158        zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
2159        cx.background_executor.run_until_parked();
2160        let completion_task = zeta.update(cx, |zeta, cx| {
2161            zeta.request_completion(&project, buffer, cursor, cx)
2162        });
2163        completion_task.await.unwrap().unwrap()
2164    }
2165
2166    async fn make_test_zeta(
2167        project: &Entity<Project>,
2168        cx: &mut TestAppContext,
2169    ) -> (
2170        Entity<Zeta>,
2171        Arc<Mutex<Option<PredictEditsBody>>>,
2172        Arc<Mutex<String>>,
2173    ) {
2174        let default_response = indoc! {"
2175            ```main.rs
2176            <|start_of_file|>
2177            <|editable_region_start|>
2178            hello world
2179            <|editable_region_end|>
2180            ```"
2181        };
2182        let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2183        let completion_response: Arc<Mutex<String>> =
2184            Arc::new(Mutex::new(default_response.to_string()));
2185        let http_client = FakeHttpClient::create({
2186            let captured_request = captured_request.clone();
2187            let completion_response = completion_response.clone();
2188            move |req| {
2189                let captured_request = captured_request.clone();
2190                let completion_response = completion_response.clone();
2191                async move {
2192                    match (req.method(), req.uri().path()) {
2193                        (&Method::POST, "/client/llm_tokens") => {
2194                            Ok(http_client::Response::builder()
2195                                .status(200)
2196                                .body(
2197                                    serde_json::to_string(&CreateLlmTokenResponse {
2198                                        token: LlmToken("the-llm-token".to_string()),
2199                                    })
2200                                    .unwrap()
2201                                    .into(),
2202                                )
2203                                .unwrap())
2204                        }
2205                        (&Method::POST, "/predict_edits/v2") => {
2206                            let mut request_body = String::new();
2207                            req.into_body().read_to_string(&mut request_body).await?;
2208                            *captured_request.lock() =
2209                                Some(serde_json::from_str(&request_body).unwrap());
2210                            Ok(http_client::Response::builder()
2211                                .status(200)
2212                                .body(
2213                                    serde_json::to_string(&PredictEditsResponse {
2214                                        request_id: Uuid::new_v4(),
2215                                        output_excerpt: completion_response.lock().clone(),
2216                                    })
2217                                    .unwrap()
2218                                    .into(),
2219                                )
2220                                .unwrap())
2221                        }
2222                        _ => Ok(http_client::Response::builder()
2223                            .status(404)
2224                            .body("Not Found".into())
2225                            .unwrap()),
2226                    }
2227                }
2228            }
2229        });
2230
2231        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2232        cx.update(|cx| {
2233            RefreshLlmTokenListener::register(client.clone(), cx);
2234        });
2235        let _server = FakeServer::for_client(42, &client, cx).await;
2236
2237        let zeta = cx.new(|cx| {
2238            let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
2239
2240            let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2241            for worktree in worktrees {
2242                let worktree_id = worktree.read(cx).id();
2243                zeta.license_detection_watchers
2244                    .entry(worktree_id)
2245                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2246            }
2247
2248            zeta
2249        });
2250
2251        (zeta, captured_request, completion_response)
2252    }
2253
2254    fn to_completion_edits(
2255        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2256        buffer: &Entity<Buffer>,
2257        cx: &App,
2258    ) -> Vec<(Range<Anchor>, String)> {
2259        let buffer = buffer.read(cx);
2260        iterator
2261            .into_iter()
2262            .map(|(range, text)| {
2263                (
2264                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2265                    text,
2266                )
2267            })
2268            .collect()
2269    }
2270
2271    fn from_completion_edits(
2272        editor_edits: &[(Range<Anchor>, String)],
2273        buffer: &Entity<Buffer>,
2274        cx: &App,
2275    ) -> Vec<(Range<usize>, String)> {
2276        let buffer = buffer.read(cx);
2277        editor_edits
2278            .iter()
2279            .map(|(range, text)| {
2280                (
2281                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
2282                    text.clone(),
2283                )
2284            })
2285            .collect()
2286    }
2287
    /// Sets up test logging via `zlog::init_test`. `#[ctor::ctor]` makes this run when the
    /// test binary is loaded, before any test executes.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2292}