zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::{Dismissable, KEY_VALUE_STORE};
  11use edit_prediction::DataCollectionState;
  12pub use init::*;
  13use license_detection::LicenseDetectionWatcher;
  14pub use rate_completion_modal::*;
  15
  16use anyhow::{Context as _, Result, anyhow};
  17use arrayvec::ArrayVec;
  18use client::{Client, EditPredictionUsage, UserStore};
  19use cloud_llm_client::{
  20    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  21    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
  22};
  23use collections::{HashMap, HashSet, VecDeque};
  24use futures::AsyncReadExt;
  25use gpui::{
  26    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  27    SharedString, Subscription, Task, actions,
  28};
  29use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  30use input_excerpt::excerpt_for_cursor_position;
  31use language::{
  32    Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  33};
  34use language_model::{LlmApiToken, RefreshLlmTokenListener};
  35use project::{Project, ProjectPath};
  36use release_channel::AppVersion;
  37use settings::WorktreeId;
  38use std::collections::hash_map;
  39use std::mem;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    ops::Range,
  46    path::Path,
  47    rc::Rc,
  48    sync::Arc,
  49    time::{Duration, Instant},
  50};
  51use telemetry_events::EditPredictionRating;
  52use thiserror::Error;
  53use util::ResultExt;
  54use util::rel_path::RelPath;
  55use uuid::Uuid;
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57use worktree::Worktree;
  58
  59const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
  60const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
  61const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
  62const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
  63const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  64const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  65
  66const MAX_CONTEXT_TOKENS: usize = 150;
  67const MAX_REWRITE_TOKENS: usize = 350;
  68const MAX_EVENT_TOKENS: usize = 500;
  69
  70/// Maximum number of events to track.
  71const MAX_EVENT_COUNT: usize = 16;
  72
// Declares the `ClearHistory` action in the `edit_prediction` namespace.
// NOTE(review): presumably wired to `Zeta::clear_history` elsewhere — confirm.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  80
  81#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  82pub struct EditPredictionId(Uuid);
  83
  84impl From<EditPredictionId> for gpui::ElementId {
  85    fn from(value: EditPredictionId) -> Self {
  86        gpui::ElementId::Uuid(value.0)
  87    }
  88}
  89
  90impl std::fmt::Display for EditPredictionId {
  91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  92        write!(f, "{}", self.0)
  93    }
  94}
  95
  96struct ZedPredictUpsell;
  97
  98impl Dismissable for ZedPredictUpsell {
  99    const KEY: &'static str = "dismissed-edit-predict-upsell";
 100
 101    fn dismissed() -> bool {
 102        // To make this backwards compatible with older versions of Zed, we
 103        // check if the user has seen the previous Edit Prediction Onboarding
 104        // before, by checking the data collection choice which was written to
 105        // the database once the user clicked on "Accept and Enable"
 106        if KEY_VALUE_STORE
 107            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 108            .log_err()
 109            .is_some_and(|s| s.is_some())
 110        {
 111            return true;
 112        }
 113
 114        KEY_VALUE_STORE
 115            .read_kvp(Self::KEY)
 116            .log_err()
 117            .is_some_and(|s| s.is_some())
 118    }
 119}
 120
 121pub fn should_show_upsell_modal() -> bool {
 122    !ZedPredictUpsell::dismissed()
 123}
 124
/// GPUI global holding the shared [`Zeta`] entity.
///
/// Installed by `Zeta::register` when no instance exists yet, and read back
/// by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 129
/// A model-produced edit prediction, re-based onto the buffer contents that
/// were current once the response had been processed.
#[derive(Clone)]
pub struct EditPrediction {
    /// Request id reported in the predict-edits response.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when it has none.
    path: Arc<Path>,
    /// Editable range chosen while gathering context.
    /// NOTE(review): presumably offsets into the request-time snapshot — confirm.
    excerpt_range: Range<usize>,
    /// Cursor offset computed against the buffer snapshot taken for the request.
    cursor_offset: usize,
    /// Predicted edits after interpolation onto `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto.
    snapshot: BufferSnapshot,
    /// Preview produced by `Buffer::preview_edits` for `edits`.
    edit_preview: EditPreview,
    // The request inputs (outline, events, excerpt) and the raw model output
    // excerpt, kept verbatim.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted at the start of the request.
    buffer_snapshotted_at: Instant,
    /// Stamped when this prediction is constructed, after the response has
    /// been parsed and interpolated.
    response_received_at: Instant,
}
 146
 147impl EditPrediction {
 148    fn latency(&self) -> Duration {
 149        self.response_received_at
 150            .duration_since(self.buffer_snapshotted_at)
 151    }
 152
 153    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 154        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 155    }
 156}
 157
/// Re-expresses model edits computed against `old_snapshot` on top of
/// `new_snapshot`, accounting for what the user typed in between.
///
/// User edits since `old_snapshot` are walked in order together with the
/// model's edits. NOTE(review): this assumes `current_edits` is sorted by
/// position — confirm with callers. The rules are:
/// - Model edits whose old range ends strictly before the user edit starts are
///   kept unchanged.
/// - Suppose a user edit replaces exactly the old range of the next model edit,
///   and its inserted text is a prefix of the model's text. Then it consumes
///   that model edit. Only the untyped suffix (if non-empty) is kept, as an
///   insertion right after the user's edit.
/// - Any other user edit invalidates the whole prediction (`None`).
///
/// Model edits remaining after the last user edit are kept as-is. Returns
/// `None` when no edits are left.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that ends strictly before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is only compatible if it covers exactly the next model
        // edit's range and types a prefix of the model's replacement text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Suggest only what the user has not typed yet.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit conflicts with the prediction.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
 203
 204impl std::fmt::Debug for EditPrediction {
 205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 206        f.debug_struct("EditPrediction")
 207            .field("id", &self.id)
 208            .field("path", &self.path)
 209            .field("edits", &self.edits)
 210            .finish_non_exhaustive()
 211    }
 212}
 213
/// Edit-prediction state shared across projects.
///
/// `Zeta::register` installs one instance as a GPUI global when none exists
/// yet.
pub struct Zeta {
    /// Per-project event history and buffer registrations, keyed by project
    /// entity id. An entry is removed when its project is released.
    projects: HashMap<EntityId, ZetaProject>,
    client: Arc<Client>,
    // NOTE(review): presumably predictions shown to the user, kept for rating — confirm.
    shown_completions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_completions: HashSet<EditPredictionId>,
    /// Loaded with `load_data_collection_choice` at construction.
    data_collection_choice: DataCollectionChoice,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    // Held only to keep the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license-detection watcher per worktree passed to `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
 227
/// Per-project prediction state.
struct ZetaProject {
    /// Recent buffer-change events. `Zeta::push_event` coalesces them and
    /// bounds them by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers tracked for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
 232
 233impl Zeta {
 234    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 235        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 236    }
 237
 238    pub fn register(
 239        worktree: Option<Entity<Worktree>>,
 240        client: Arc<Client>,
 241        user_store: Entity<UserStore>,
 242        cx: &mut App,
 243    ) -> Entity<Self> {
 244        let this = Self::global(cx).unwrap_or_else(|| {
 245            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 246            cx.set_global(ZetaGlobal(entity.clone()));
 247            entity
 248        });
 249
 250        this.update(cx, move |this, cx| {
 251            if let Some(worktree) = worktree {
 252                let worktree_id = worktree.read(cx).id();
 253                this.license_detection_watchers
 254                    .entry(worktree_id)
 255                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 256            }
 257        });
 258
 259        this
 260    }
 261
 262    pub fn clear_history(&mut self) {
 263        for zeta_project in self.projects.values_mut() {
 264            zeta_project.events.clear();
 265        }
 266    }
 267
 268    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 269        self.user_store.read(cx).edit_prediction_usage()
 270    }
 271
 272    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 273        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 274        let data_collection_choice = Self::load_data_collection_choice();
 275        Self {
 276            projects: HashMap::default(),
 277            client,
 278            shown_completions: VecDeque::new(),
 279            rated_completions: HashSet::default(),
 280            data_collection_choice,
 281            llm_token: LlmApiToken::default(),
 282            _llm_token_subscription: cx.subscribe(
 283                &refresh_llm_token_listener,
 284                |this, _listener, _event, cx| {
 285                    let client = this.client.clone();
 286                    let llm_token = this.llm_token.clone();
 287                    cx.spawn(async move |_this, _cx| {
 288                        llm_token.refresh(&client).await?;
 289                        anyhow::Ok(())
 290                    })
 291                    .detach_and_log_err(cx);
 292                },
 293            ),
 294            update_required: false,
 295            license_detection_watchers: HashMap::default(),
 296            user_store,
 297        }
 298    }
 299
 300    fn get_or_init_zeta_project(
 301        &mut self,
 302        project: &Entity<Project>,
 303        cx: &mut Context<Self>,
 304    ) -> &mut ZetaProject {
 305        let project_id = project.entity_id();
 306        match self.projects.entry(project_id) {
 307            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 308            hash_map::Entry::Vacant(entry) => {
 309                cx.observe_release(project, move |this, _, _cx| {
 310                    this.projects.remove(&project_id);
 311                })
 312                .detach();
 313                entry.insert(ZetaProject {
 314                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 315                    registered_buffers: HashMap::default(),
 316                })
 317            }
 318        }
 319    }
 320
 321    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 322        let events = &mut zeta_project.events;
 323
 324        if let Some(Event::BufferChange {
 325            new_snapshot: last_new_snapshot,
 326            timestamp: last_timestamp,
 327            ..
 328        }) = events.back_mut()
 329        {
 330            // Coalesce edits for the same buffer when they happen one after the other.
 331            let Event::BufferChange {
 332                old_snapshot,
 333                new_snapshot,
 334                timestamp,
 335            } = &event;
 336
 337            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 338                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 339                && old_snapshot.version == last_new_snapshot.version
 340            {
 341                *last_new_snapshot = new_snapshot.clone();
 342                *last_timestamp = *timestamp;
 343                return;
 344            }
 345        }
 346
 347        if events.len() >= MAX_EVENT_COUNT {
 348            // These are halved instead of popping to improve prompt caching.
 349            events.drain(..MAX_EVENT_COUNT / 2);
 350        }
 351
 352        events.push_back(event);
 353    }
 354
 355    pub fn register_buffer(
 356        &mut self,
 357        buffer: &Entity<Buffer>,
 358        project: &Entity<Project>,
 359        cx: &mut Context<Self>,
 360    ) {
 361        let zeta_project = self.get_or_init_zeta_project(project, cx);
 362        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 363    }
 364
 365    fn register_buffer_impl<'a>(
 366        zeta_project: &'a mut ZetaProject,
 367        buffer: &Entity<Buffer>,
 368        project: &Entity<Project>,
 369        cx: &mut Context<Self>,
 370    ) -> &'a mut RegisteredBuffer {
 371        let buffer_id = buffer.entity_id();
 372        match zeta_project.registered_buffers.entry(buffer_id) {
 373            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 374            hash_map::Entry::Vacant(entry) => {
 375                let snapshot = buffer.read(cx).snapshot();
 376                let project_entity_id = project.entity_id();
 377                entry.insert(RegisteredBuffer {
 378                    snapshot,
 379                    _subscriptions: [
 380                        cx.subscribe(buffer, {
 381                            let project = project.downgrade();
 382                            move |this, buffer, event, cx| {
 383                                if let language::BufferEvent::Edited = event
 384                                    && let Some(project) = project.upgrade()
 385                                {
 386                                    this.report_changes_for_buffer(&buffer, &project, cx);
 387                                }
 388                            }
 389                        }),
 390                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 391                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 392                            else {
 393                                return;
 394                            };
 395                            zeta_project.registered_buffers.remove(&buffer_id);
 396                        }),
 397                    ],
 398                })
 399            }
 400        }
 401    }
 402
    /// Shared implementation of `request_completion` and `fake_completion`.
    ///
    /// Steps:
    /// 1. Gather prompt context around `cursor` in `buffer`.
    /// 2. Pass the request body to `perform_predict_edits`. This is the real
    ///    network call, or a canned response in tests.
    /// 3. Turn the response into an `EditPrediction` re-based onto the buffer's
    ///    current contents.
    ///
    /// Resolves to `Ok(None)` when the prediction no longer applies after
    /// interpolation. On a `ZedUpdateRequiredError`, `update_required` is set
    /// and an "Update Zed" notification is shown before the error is returned.
    fn request_completion_impl<F, R>(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency window reported in telemetry below.
        let buffer_snapshotted_at = Instant::now();
        // NOTE(review): presumably records the buffer's latest edits as an event
        // before the history is read below — confirm in `report_changes_for_buffer`.
        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
        let zeta = cx.entity();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Copy the project's event history for use by this request.
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let mut events = Vec::with_capacity(zeta_project.events.len());
        events.extend(zeta_project.events.iter().cloned());
        let events = Arc::new(events);

        // Git info is looked up only when the file itself may be collected.
        let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
            let can_collect_file = self.can_collect_file(file, cx);
            let git_info = if can_collect_file {
                git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
            } else {
                None
            };
            (git_info, can_collect_file)
        } else {
            (None, false)
        };

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().into_owned();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let prompt_for_events = {
            let events = events.clone();
            move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
        };
        let gather_task = gather_context(
            full_path_str,
            &snapshot,
            cursor_point,
            prompt_for_events,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                mut body,
                editable_range,
                included_events_count,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            // The tail of the history that made it into the prompt. Data may be
            // collected only if the file and all of these events allow it.
            // Assumes `included_events_count <= events.len()`.
            let included_events = &events[events.len() - included_events_count..events.len()];
            body.can_collect_data = can_collect_file
                && this
                    .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                    .unwrap_or(false);
            if body.can_collect_data {
                body.git_info = git_info;
            }

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the prompt inputs for the resulting EditPrediction.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // Remember that an update is required and tell the user.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage recording is best-effort: a released `Zeta` is ignored.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // Record latency for ~1% of requests (3 of the 256 possible u8 values).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 572
 573    #[cfg(any(test, feature = "test-support"))]
 574    pub fn fake_completion(
 575        &mut self,
 576        project: &Entity<Project>,
 577        buffer: &Entity<Buffer>,
 578        position: language::Anchor,
 579        response: PredictEditsResponse,
 580        cx: &mut Context<Self>,
 581    ) -> Task<Result<Option<EditPrediction>>> {
 582        self.request_completion_impl(project, buffer, position, cx, |_params| {
 583            std::future::ready(Ok((response, None)))
 584        })
 585    }
 586
 587    pub fn request_completion(
 588        &mut self,
 589        project: &Entity<Project>,
 590        buffer: &Entity<Buffer>,
 591        position: language::Anchor,
 592        cx: &mut Context<Self>,
 593    ) -> Task<Result<Option<EditPrediction>>> {
 594        self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
 595    }
 596
 597    pub fn perform_predict_edits(
 598        params: PerformPredictEditsParams,
 599    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 600        async move {
 601            let PerformPredictEditsParams {
 602                client,
 603                llm_token,
 604                app_version,
 605                body,
 606                ..
 607            } = params;
 608
 609            let http_client = client.http_client();
 610            let mut token = llm_token.acquire(&client).await?;
 611            let mut did_retry = false;
 612
 613            loop {
 614                let request_builder = http_client::Request::builder().method(Method::POST);
 615                let request_builder =
 616                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 617                        request_builder.uri(predict_edits_url)
 618                    } else {
 619                        request_builder.uri(
 620                            http_client
 621                                .build_zed_llm_url("/predict_edits/v2", &[])?
 622                                .as_ref(),
 623                        )
 624                    };
 625                let request = request_builder
 626                    .header("Content-Type", "application/json")
 627                    .header("Authorization", format!("Bearer {}", token))
 628                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 629                    .body(serde_json::to_string(&body)?.into())?;
 630
 631                let mut response = http_client.send(request).await?;
 632
 633                if let Some(minimum_required_version) = response
 634                    .headers()
 635                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 636                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 637                {
 638                    anyhow::ensure!(
 639                        app_version >= minimum_required_version,
 640                        ZedUpdateRequiredError {
 641                            minimum_version: minimum_required_version
 642                        }
 643                    );
 644                }
 645
 646                if response.status().is_success() {
 647                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 648
 649                    let mut body = String::new();
 650                    response.body_mut().read_to_string(&mut body).await?;
 651                    return Ok((serde_json::from_str(&body)?, usage));
 652                } else if !did_retry
 653                    && response
 654                        .headers()
 655                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 656                        .is_some()
 657                {
 658                    did_retry = true;
 659                    token = llm_token.refresh(&client).await?;
 660                } else {
 661                    let mut body = String::new();
 662                    response.body_mut().read_to_string(&mut body).await?;
 663                    anyhow::bail!(
 664                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 665                        response.status(),
 666                        body
 667                    );
 668                }
 669            }
 670        }
 671    }
 672
 673    fn accept_edit_prediction(
 674        &mut self,
 675        request_id: EditPredictionId,
 676        cx: &mut Context<Self>,
 677    ) -> Task<Result<()>> {
 678        let client = self.client.clone();
 679        let llm_token = self.llm_token.clone();
 680        let app_version = AppVersion::global(cx);
 681        cx.spawn(async move |this, cx| {
 682            let http_client = client.http_client();
 683            let mut response = llm_token_retry(&llm_token, &client, |token| {
 684                let request_builder = http_client::Request::builder().method(Method::POST);
 685                let request_builder =
 686                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 687                        request_builder.uri(accept_prediction_url)
 688                    } else {
 689                        request_builder.uri(
 690                            http_client
 691                                .build_zed_llm_url("/predict_edits/accept", &[])?
 692                                .as_ref(),
 693                        )
 694                    };
 695                Ok(request_builder
 696                    .header("Content-Type", "application/json")
 697                    .header("Authorization", format!("Bearer {}", token))
 698                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 699                    .body(
 700                        serde_json::to_string(&AcceptEditPredictionBody {
 701                            request_id: request_id.0,
 702                        })?
 703                        .into(),
 704                    )?)
 705            })
 706            .await?;
 707
 708            if let Some(minimum_required_version) = response
 709                .headers()
 710                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 711                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 712                && app_version < minimum_required_version
 713            {
 714                return Err(anyhow!(ZedUpdateRequiredError {
 715                    minimum_version: minimum_required_version
 716                }));
 717            }
 718
 719            if response.status().is_success() {
 720                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 721                    this.update(cx, |this, cx| {
 722                        this.user_store.update(cx, |user_store, cx| {
 723                            user_store.update_edit_prediction_usage(usage, cx);
 724                        });
 725                    })?;
 726                }
 727
 728                Ok(())
 729            } else {
 730                let mut body = String::new();
 731                response.body_mut().read_to_string(&mut body).await?;
 732                Err(anyhow!(
 733                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 734                    response.status(),
 735                    body
 736                ))
 737            }
 738        })
 739    }
 740
    /// Converts a raw predict-edits response into an `EditPrediction`.
    ///
    /// Steps:
    /// 1. Parse the model's output excerpt into edits on a background thread,
    ///    against `snapshot` (the buffer state the request was built from).
    /// 2. Interpolate those edits onto the buffer's current snapshot, to
    ///    account for edits made while the request was in flight.
    /// 3. Build an edit preview.
    ///
    /// Resolves to `Ok(None)` when interpolation discards the prediction. Fails
    /// when parsing fails or when `read_with` on the async context fails.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse off the foreground thread; edits are relative to `snapshot`.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-base onto the current buffer contents; `None` means the user's
            // intervening edits invalidated the prediction.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 802
 803    fn parse_edits(
 804        output_excerpt: Arc<str>,
 805        editable_range: Range<usize>,
 806        snapshot: &BufferSnapshot,
 807    ) -> Result<Vec<(Range<Anchor>, String)>> {
 808        let content = output_excerpt.replace(CURSOR_MARKER, "");
 809
 810        let start_markers = content
 811            .match_indices(EDITABLE_REGION_START_MARKER)
 812            .collect::<Vec<_>>();
 813        anyhow::ensure!(
 814            start_markers.len() == 1,
 815            "expected exactly one start marker, found {}",
 816            start_markers.len()
 817        );
 818
 819        let end_markers = content
 820            .match_indices(EDITABLE_REGION_END_MARKER)
 821            .collect::<Vec<_>>();
 822        anyhow::ensure!(
 823            end_markers.len() == 1,
 824            "expected exactly one end marker, found {}",
 825            end_markers.len()
 826        );
 827
 828        let sof_markers = content
 829            .match_indices(START_OF_FILE_MARKER)
 830            .collect::<Vec<_>>();
 831        anyhow::ensure!(
 832            sof_markers.len() <= 1,
 833            "expected at most one start-of-file marker, found {}",
 834            sof_markers.len()
 835        );
 836
 837        let codefence_start = start_markers[0].0;
 838        let content = &content[codefence_start..];
 839
 840        let newline_ix = content.find('\n').context("could not find newline")?;
 841        let content = &content[newline_ix + 1..];
 842
 843        let codefence_end = content
 844            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 845            .context("could not find end marker")?;
 846        let new_text = &content[..codefence_end];
 847
 848        let old_text = snapshot
 849            .text_for_range(editable_range.clone())
 850            .collect::<String>();
 851
 852        Ok(Self::compute_edits(
 853            old_text,
 854            new_text,
 855            editable_range.start,
 856            snapshot,
 857        ))
 858    }
 859
 860    pub fn compute_edits(
 861        old_text: String,
 862        new_text: &str,
 863        offset: usize,
 864        snapshot: &BufferSnapshot,
 865    ) -> Vec<(Range<Anchor>, String)> {
 866        text_diff(&old_text, new_text)
 867            .into_iter()
 868            .map(|(mut old_range, new_text)| {
 869                old_range.start += offset;
 870                old_range.end += offset;
 871
 872                let prefix_len = common_prefix(
 873                    snapshot.chars_for_range(old_range.clone()),
 874                    new_text.chars(),
 875                );
 876                old_range.start += prefix_len;
 877
 878                let suffix_len = common_prefix(
 879                    snapshot.reversed_chars_for_range(old_range.clone()),
 880                    new_text[prefix_len..].chars().rev(),
 881                );
 882                old_range.end = old_range.end.saturating_sub(suffix_len);
 883
 884                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 885                let range = if old_range.is_empty() {
 886                    let anchor = snapshot.anchor_after(old_range.start);
 887                    anchor..anchor
 888                } else {
 889                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 890                };
 891                (range, new_text)
 892            })
 893            .collect()
 894    }
 895
 896    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
 897        self.rated_completions.contains(&completion_id)
 898    }
 899
 900    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
 901        self.shown_completions.push_front(completion.clone());
 902        if self.shown_completions.len() > 50 {
 903            let completion = self.shown_completions.pop_back().unwrap();
 904            self.rated_completions.remove(&completion.id);
 905        }
 906        cx.notify();
 907    }
 908
    /// Records the user's rating of `completion` and reports it, together with the
    /// completion's inputs, output and `feedback`, as an "Edit Prediction Rated" telemetry
    /// event.
    ///
    /// The completion's id is marked as rated, so `is_completion_rated` returns `true` for it.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        // NOTE(review): `rating` and `feedback` are passed in field-shorthand form, so their
        // local names presumably double as property names on the event — renaming them would
        // change the event schema; confirm against `telemetry::event!`.
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush queued telemetry now; the returned task is detached rather than awaited.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
 929
 930    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
 931        self.shown_completions.iter()
 932    }
 933
 934    pub fn shown_completions_len(&self) -> usize {
 935        self.shown_completions.len()
 936    }
 937
 938    fn report_changes_for_buffer(
 939        &mut self,
 940        buffer: &Entity<Buffer>,
 941        project: &Entity<Project>,
 942        cx: &mut Context<Self>,
 943    ) -> BufferSnapshot {
 944        let zeta_project = self.get_or_init_zeta_project(project, cx);
 945        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 946
 947        let new_snapshot = buffer.read(cx).snapshot();
 948        if new_snapshot.version != registered_buffer.snapshot.version {
 949            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 950            Self::push_event(
 951                zeta_project,
 952                Event::BufferChange {
 953                    old_snapshot,
 954                    new_snapshot: new_snapshot.clone(),
 955                    timestamp: Instant::now(),
 956                },
 957            );
 958        }
 959
 960        new_snapshot
 961    }
 962
 963    fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
 964        self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
 965    }
 966
 967    fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
 968        if !self.data_collection_choice.is_enabled() {
 969            return false;
 970        }
 971        let mut last_checked_file = None;
 972        for event in events {
 973            match event {
 974                Event::BufferChange {
 975                    old_snapshot,
 976                    new_snapshot,
 977                    ..
 978                } => {
 979                    if let Some(old_file) = old_snapshot.file()
 980                        && let Some(new_file) = new_snapshot.file()
 981                    {
 982                        if let Some(last_checked_file) = last_checked_file
 983                            && Arc::ptr_eq(last_checked_file, old_file)
 984                            && Arc::ptr_eq(last_checked_file, new_file)
 985                        {
 986                            continue;
 987                        }
 988                        if !self.can_collect_file(old_file, cx) {
 989                            return false;
 990                        }
 991                        if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
 992                        {
 993                            return false;
 994                        }
 995                        last_checked_file = Some(new_file);
 996                    } else {
 997                        return false;
 998                    }
 999                }
1000            }
1001        }
1002        true
1003    }
1004
1005    fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1006        if !file.is_local() || file.is_private() {
1007            return false;
1008        }
1009        self.license_detection_watchers
1010            .get(&file.worktree_id(cx))
1011            .is_some_and(|watcher| watcher.is_project_open_source())
1012    }
1013
1014    fn load_data_collection_choice() -> DataCollectionChoice {
1015        let choice = KEY_VALUE_STORE
1016            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1017            .log_err()
1018            .flatten();
1019
1020        match choice.as_deref() {
1021            Some("true") => DataCollectionChoice::Enabled,
1022            Some("false") => DataCollectionChoice::Disabled,
1023            Some(_) => {
1024                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1025                DataCollectionChoice::NotAnswered
1026            }
1027            None => DataCollectionChoice::NotAnswered,
1028        }
1029    }
1030
1031    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1032        self.data_collection_choice = self.data_collection_choice.toggle();
1033        let new_choice = self.data_collection_choice;
1034        db::write_and_log(cx, move || {
1035            KEY_VALUE_STORE.write_kvp(
1036                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1037                new_choice.is_enabled().to_string(),
1038            )
1039        });
1040    }
1041}
1042
/// Everything needed to issue one predict-edits request.
pub struct PerformPredictEditsParams {
    /// Client whose HTTP connection is used to send the request.
    pub client: Arc<Client>,
    /// LLM API token used to authenticate the request.
    pub llm_token: LlmApiToken,
    // NOTE(review): presumably the running Zed version, reported to the server — confirm at
    // the request-building site.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1049
/// Error indicating that edit predictions require at least `minimum_version` of Zed; its
/// message asks the user to update.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version that may keep using edit predictions.
    minimum_version: SemanticVersion,
}
1057
/// Returns the UTF-8 byte length of the longest common prefix of two character streams.
///
/// Comparison stops at the first mismatch or when either stream ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1064
1065fn git_info_for_file(
1066    project: &Entity<Project>,
1067    project_path: &ProjectPath,
1068    cx: &App,
1069) -> Option<PredictEditsGitInfo> {
1070    let git_store = project.read(cx).git_store().read(cx);
1071    if let Some((repository, _repo_path)) =
1072        git_store.repository_and_path_for_project_path(project_path, cx)
1073    {
1074        let repository = repository.read(cx);
1075        let head_sha = repository
1076            .head_commit
1077            .as_ref()
1078            .map(|head_commit| head_commit.sha.to_string());
1079        let remote_origin_url = repository.remote_origin_url.clone();
1080        let remote_upstream_url = repository.remote_upstream_url.clone();
1081        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1082            return None;
1083        }
1084        Some(PredictEditsGitInfo {
1085            head_sha,
1086            remote_origin_url,
1087            remote_upstream_url,
1088        })
1089    } else {
1090        None
1091    }
1092}
1093
/// Result of [`gather_context`]: the request body plus bookkeeping about what went into it.
pub struct GatherContextOutput {
    /// Request body. `gather_context` fills in only the events and the excerpt; data
    /// collection is off and the optional context fields are `None`.
    pub body: PredictEditsBody,
    /// Editable range of the excerpt, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
    /// Number of events the `prompt_for_events` callback reported as included.
    pub included_events_count: usize,
}
1099
1100pub fn gather_context(
1101    full_path_str: String,
1102    snapshot: &BufferSnapshot,
1103    cursor_point: language::Point,
1104    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1105    cx: &App,
1106) -> Task<Result<GatherContextOutput>> {
1107    cx.background_spawn({
1108        let snapshot = snapshot.clone();
1109        async move {
1110            let input_excerpt = excerpt_for_cursor_position(
1111                cursor_point,
1112                &full_path_str,
1113                &snapshot,
1114                MAX_REWRITE_TOKENS,
1115                MAX_CONTEXT_TOKENS,
1116            );
1117            let (input_events, included_events_count) = prompt_for_events();
1118            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1119
1120            let body = PredictEditsBody {
1121                input_events,
1122                input_excerpt: input_excerpt.prompt,
1123                can_collect_data: false,
1124                diagnostic_groups: None,
1125                git_info: None,
1126                outline: None,
1127                speculated_output: None,
1128            };
1129
1130            Ok(GatherContextOutput {
1131                body,
1132                editable_range,
1133                included_events_count,
1134            })
1135        }
1136    })
1137}
1138
1139fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1140    let mut result = String::new();
1141    for (ix, event) in events.iter().rev().enumerate() {
1142        let event_string = event.to_prompt();
1143        let event_tokens = guess_token_count(event_string.len());
1144        if event_tokens > remaining_tokens {
1145            return (result, ix);
1146        }
1147
1148        if !result.is_empty() {
1149            result.insert_str(0, "\n\n");
1150        }
1151        result.insert_str(0, &event_string);
1152        remaining_tokens -= event_tokens;
1153    }
1154    return (result, events.len());
1155}
1156
/// Per-buffer state kept for buffers registered with Zeta.
struct RegisteredBuffer {
    /// Last recorded snapshot. `report_changes_for_buffer` compares the buffer's current
    /// version against it and replaces it when the version has moved.
    snapshot: BufferSnapshot,
    // NOTE(review): never read; presumably held so both subscriptions stay alive while the
    // buffer is registered — confirm in `register_buffer_impl`.
    _subscriptions: [gpui::Subscription; 2],
}
1161
/// A recorded user action that can be rendered into the edit-prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// The buffer went from `old_snapshot` to `new_snapshot`; content and/or file path may
    /// differ between the two.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1170
1171impl Event {
1172    fn to_prompt(&self) -> String {
1173        match self {
1174            Event::BufferChange {
1175                old_snapshot,
1176                new_snapshot,
1177                ..
1178            } => {
1179                let mut prompt = String::new();
1180
1181                let old_path = old_snapshot
1182                    .file()
1183                    .map(|f| f.path().as_ref())
1184                    .unwrap_or(RelPath::unix("untitled").unwrap());
1185                let new_path = new_snapshot
1186                    .file()
1187                    .map(|f| f.path().as_ref())
1188                    .unwrap_or(RelPath::unix("untitled").unwrap());
1189                if old_path != new_path {
1190                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1191                }
1192
1193                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1194                if !diff.is_empty() {
1195                    write!(
1196                        prompt,
1197                        "User edited {:?}:\n```diff\n{}\n```",
1198                        new_path, diff
1199                    )
1200                    .unwrap();
1201                }
1202
1203                prompt
1204            }
1205        }
1206    }
1207}
1208
/// The prediction currently offered by `ZetaEditPredictionProvider`, tagged with the buffer it
/// was requested for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1214
1215impl CurrentEditPrediction {
1216    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1217        if self.buffer_id != old_completion.buffer_id {
1218            return true;
1219        }
1220
1221        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1222            return true;
1223        };
1224        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1225            return false;
1226        };
1227
1228        if old_edits.len() == 1 && new_edits.len() == 1 {
1229            let (old_range, old_text) = &old_edits[0];
1230            let (new_range, new_text) = &new_edits[0];
1231            new_range == old_range && new_text.starts_with(old_text)
1232        } else {
1233            true
1234        }
1235    }
1236}
1237
/// An in-flight prediction request started by `ZetaEditPredictionProvider::refresh`.
struct PendingCompletion {
    /// Sequence number taken from `next_pending_completion_id`; the finishing task uses it to
    /// find its own entry.
    id: usize,
    // NOTE(review): never read; presumably dropping it cancels the request (gpui `Task`
    // semantics) — confirm.
    _task: Task<()>,
}
1242
/// The user's answer to whether edit-prediction data may be collected.
///
/// Derives `PartialEq`/`Eq` so callers can compare choices with `==` instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// The user has not chosen yet; collection stays off.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    /// Returns `true` once the user has explicitly opted in or out.
    pub fn is_answered(self) -> bool {
        match self {
            Self::Enabled | Self::Disabled => true,
            Self::NotAnswered => false,
        }
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        match value {
            true => DataCollectionChoice::Enabled,
            false => DataCollectionChoice::Disabled,
        }
    }
}
1283
1284async fn llm_token_retry(
1285    llm_token: &LlmApiToken,
1286    client: &Arc<Client>,
1287    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1288) -> Result<Response<AsyncBody>> {
1289    let mut did_retry = false;
1290    let http_client = client.http_client();
1291    let mut token = llm_token.acquire(client).await?;
1292    loop {
1293        let request = build_request(token.clone())?;
1294        let response = http_client.send(request).await?;
1295
1296        if !did_retry
1297            && !response.status().is_success()
1298            && response
1299                .headers()
1300                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1301                .is_some()
1302        {
1303            did_retry = true;
1304            token = llm_token.refresh(client).await?;
1305            continue;
1306        }
1307
1308        return Ok(response);
1309    }
1310}
1311
/// Adapts a [`Zeta`] entity to the editor's `EditPredictionProvider` interface.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// The buffer being edited when it is a singleton; used to report data-collection state.
    singleton_buffer: Option<Entity<Buffer>>,
    /// At most two in-flight requests, oldest first (see `refresh`).
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id assigned to the next request started by `refresh`.
    next_pending_completion_id: usize,
    /// Prediction currently offered to the editor, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was issued; `refresh` throttles new requests against it.
    last_request_timestamp: Instant,
    project: Entity<Project>,
}
1321
1322impl ZetaEditPredictionProvider {
1323    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1324
1325    pub fn new(
1326        zeta: Entity<Zeta>,
1327        project: Entity<Project>,
1328        singleton_buffer: Option<Entity<Buffer>>,
1329    ) -> Self {
1330        Self {
1331            zeta,
1332            singleton_buffer,
1333            pending_completions: ArrayVec::new(),
1334            next_pending_completion_id: 0,
1335            current_completion: None,
1336            last_request_timestamp: Instant::now(),
1337            project,
1338        }
1339    }
1340}
1341
1342impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1343    fn name() -> &'static str {
1344        "zed-predict"
1345    }
1346
1347    fn display_name() -> &'static str {
1348        "Zed's Edit Predictions"
1349    }
1350
1351    fn show_completions_in_menu() -> bool {
1352        true
1353    }
1354
1355    fn show_tab_accept_marker() -> bool {
1356        true
1357    }
1358
1359    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1360        if let Some(buffer) = &self.singleton_buffer
1361            && let Some(file) = buffer.read(cx).file()
1362        {
1363            let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
1364            if self.zeta.read(cx).data_collection_choice.is_enabled() {
1365                DataCollectionState::Enabled {
1366                    is_project_open_source,
1367                }
1368            } else {
1369                DataCollectionState::Disabled {
1370                    is_project_open_source,
1371                }
1372            }
1373        } else {
1374            return DataCollectionState::Disabled {
1375                is_project_open_source: false,
1376            };
1377        }
1378    }
1379
1380    fn toggle_data_collection(&mut self, cx: &mut App) {
1381        self.zeta
1382            .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
1383    }
1384
1385    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1386        self.zeta.read(cx).usage(cx)
1387    }
1388
1389    fn is_enabled(
1390        &self,
1391        _buffer: &Entity<Buffer>,
1392        _cursor_position: language::Anchor,
1393        _cx: &App,
1394    ) -> bool {
1395        true
1396    }
1397    fn is_refreshing(&self) -> bool {
1398        !self.pending_completions.is_empty()
1399    }
1400
1401    fn refresh(
1402        &mut self,
1403        buffer: Entity<Buffer>,
1404        position: language::Anchor,
1405        _debounce: bool,
1406        cx: &mut Context<Self>,
1407    ) {
1408        if self.zeta.read(cx).update_required {
1409            return;
1410        }
1411
1412        if self
1413            .zeta
1414            .read(cx)
1415            .user_store
1416            .read_with(cx, |user_store, _cx| {
1417                user_store.account_too_young() || user_store.has_overdue_invoices()
1418            })
1419        {
1420            return;
1421        }
1422
1423        if let Some(current_completion) = self.current_completion.as_ref() {
1424            let snapshot = buffer.read(cx).snapshot();
1425            if current_completion
1426                .completion
1427                .interpolate(&snapshot)
1428                .is_some()
1429            {
1430                return;
1431            }
1432        }
1433
1434        let pending_completion_id = self.next_pending_completion_id;
1435        self.next_pending_completion_id += 1;
1436        let last_request_timestamp = self.last_request_timestamp;
1437
1438        let project = self.project.clone();
1439        let task = cx.spawn(async move |this, cx| {
1440            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1441                .checked_duration_since(Instant::now())
1442            {
1443                cx.background_executor().timer(timeout).await;
1444            }
1445
1446            let completion_request = this.update(cx, |this, cx| {
1447                this.last_request_timestamp = Instant::now();
1448                this.zeta.update(cx, |zeta, cx| {
1449                    zeta.request_completion(&project, &buffer, position, cx)
1450                })
1451            });
1452
1453            let completion = match completion_request {
1454                Ok(completion_request) => {
1455                    let completion_request = completion_request.await;
1456                    completion_request.map(|c| {
1457                        c.map(|completion| CurrentEditPrediction {
1458                            buffer_id: buffer.entity_id(),
1459                            completion,
1460                        })
1461                    })
1462                }
1463                Err(error) => Err(error),
1464            };
1465            let Some(new_completion) = completion
1466                .context("edit prediction failed")
1467                .log_err()
1468                .flatten()
1469            else {
1470                this.update(cx, |this, cx| {
1471                    if this.pending_completions[0].id == pending_completion_id {
1472                        this.pending_completions.remove(0);
1473                    } else {
1474                        this.pending_completions.clear();
1475                    }
1476
1477                    cx.notify();
1478                })
1479                .ok();
1480                return;
1481            };
1482
1483            this.update(cx, |this, cx| {
1484                if this.pending_completions[0].id == pending_completion_id {
1485                    this.pending_completions.remove(0);
1486                } else {
1487                    this.pending_completions.clear();
1488                }
1489
1490                if let Some(old_completion) = this.current_completion.as_ref() {
1491                    let snapshot = buffer.read(cx).snapshot();
1492                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1493                        this.zeta.update(cx, |zeta, cx| {
1494                            zeta.completion_shown(&new_completion.completion, cx);
1495                        });
1496                        this.current_completion = Some(new_completion);
1497                    }
1498                } else {
1499                    this.zeta.update(cx, |zeta, cx| {
1500                        zeta.completion_shown(&new_completion.completion, cx);
1501                    });
1502                    this.current_completion = Some(new_completion);
1503                }
1504
1505                cx.notify();
1506            })
1507            .ok();
1508        });
1509
1510        // We always maintain at most two pending completions. When we already
1511        // have two, we replace the newest one.
1512        if self.pending_completions.len() <= 1 {
1513            self.pending_completions.push(PendingCompletion {
1514                id: pending_completion_id,
1515                _task: task,
1516            });
1517        } else if self.pending_completions.len() == 2 {
1518            self.pending_completions.pop();
1519            self.pending_completions.push(PendingCompletion {
1520                id: pending_completion_id,
1521                _task: task,
1522            });
1523        }
1524    }
1525
1526    fn cycle(
1527        &mut self,
1528        _buffer: Entity<Buffer>,
1529        _cursor_position: language::Anchor,
1530        _direction: edit_prediction::Direction,
1531        _cx: &mut Context<Self>,
1532    ) {
1533        // Right now we don't support cycling.
1534    }
1535
1536    fn accept(&mut self, cx: &mut Context<Self>) {
1537        let completion_id = self
1538            .current_completion
1539            .as_ref()
1540            .map(|completion| completion.completion.id);
1541        if let Some(completion_id) = completion_id {
1542            self.zeta
1543                .update(cx, |zeta, cx| {
1544                    zeta.accept_edit_prediction(completion_id, cx)
1545                })
1546                .detach();
1547        }
1548        self.pending_completions.clear();
1549    }
1550
1551    fn discard(&mut self, _cx: &mut Context<Self>) {
1552        self.pending_completions.clear();
1553        self.current_completion.take();
1554    }
1555
1556    fn suggest(
1557        &mut self,
1558        buffer: &Entity<Buffer>,
1559        cursor_position: language::Anchor,
1560        cx: &mut Context<Self>,
1561    ) -> Option<edit_prediction::EditPrediction> {
1562        let CurrentEditPrediction {
1563            buffer_id,
1564            completion,
1565            ..
1566        } = self.current_completion.as_mut()?;
1567
1568        // Invalidate previous completion if it was generated for a different buffer.
1569        if *buffer_id != buffer.entity_id() {
1570            self.current_completion.take();
1571            return None;
1572        }
1573
1574        let buffer = buffer.read(cx);
1575        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1576            self.current_completion.take();
1577            return None;
1578        };
1579
1580        let cursor_row = cursor_position.to_point(buffer).row;
1581        let (closest_edit_ix, (closest_edit_range, _)) =
1582            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1583                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1584                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1585                cmp::min(distance_from_start, distance_from_end)
1586            })?;
1587
1588        let mut edit_start_ix = closest_edit_ix;
1589        for (range, _) in edits[..edit_start_ix].iter().rev() {
1590            let distance_from_closest_edit =
1591                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1592            if distance_from_closest_edit <= 1 {
1593                edit_start_ix -= 1;
1594            } else {
1595                break;
1596            }
1597        }
1598
1599        let mut edit_end_ix = closest_edit_ix + 1;
1600        for (range, _) in &edits[edit_end_ix..] {
1601            let distance_from_closest_edit =
1602                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1603            if distance_from_closest_edit <= 1 {
1604                edit_end_ix += 1;
1605            } else {
1606                break;
1607            }
1608        }
1609
1610        Some(edit_prediction::EditPrediction::Local {
1611            id: Some(completion.id.to_string().into()),
1612            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1613            edit_preview: Some(completion.edit_preview.clone()),
1614        })
1615    }
1616}
1617
/// Assumed number of UTF-8 bytes per model token, used only to budget prompt size. The value
/// is deliberately low, which inflates token estimates and so errs toward underestimating how
/// much text fits within a limit.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates the token count of `bytes` bytes of text, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1625
1626#[cfg(test)]
1627mod tests {
1628    use client::test::FakeServer;
1629    use clock::FakeSystemClock;
1630    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1631    use gpui::TestAppContext;
1632    use http_client::FakeHttpClient;
1633    use indoc::indoc;
1634    use language::Point;
1635    use parking_lot::Mutex;
1636    use serde_json::json;
1637    use settings::SettingsStore;
1638    use util::{path, rel_path::rel_path};
1639
1640    use super::*;
1641
1642    const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1643
1644    #[gpui::test]
1645    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1646        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1647        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1648            to_completion_edits(
1649                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1650                &buffer,
1651                cx,
1652            )
1653            .into()
1654        });
1655
1656        let edit_preview = cx
1657            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1658            .await;
1659
1660        let completion = EditPrediction {
1661            edits,
1662            edit_preview,
1663            path: Path::new("").into(),
1664            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1665            id: EditPredictionId(Uuid::new_v4()),
1666            excerpt_range: 0..0,
1667            cursor_offset: 0,
1668            input_outline: "".into(),
1669            input_events: "".into(),
1670            input_excerpt: "".into(),
1671            output_excerpt: "".into(),
1672            buffer_snapshotted_at: Instant::now(),
1673            response_received_at: Instant::now(),
1674        };
1675
1676        cx.update(|cx| {
1677            assert_eq!(
1678                from_completion_edits(
1679                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1680                    &buffer,
1681                    cx
1682                ),
1683                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1684            );
1685
1686            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1687            assert_eq!(
1688                from_completion_edits(
1689                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1690                    &buffer,
1691                    cx
1692                ),
1693                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1694            );
1695
1696            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1697            assert_eq!(
1698                from_completion_edits(
1699                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1700                    &buffer,
1701                    cx
1702                ),
1703                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1704            );
1705
1706            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1707            assert_eq!(
1708                from_completion_edits(
1709                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1710                    &buffer,
1711                    cx
1712                ),
1713                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1714            );
1715
1716            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1717            assert_eq!(
1718                from_completion_edits(
1719                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1720                    &buffer,
1721                    cx
1722                ),
1723                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1724            );
1725
1726            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1727            assert_eq!(
1728                from_completion_edits(
1729                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1730                    &buffer,
1731                    cx
1732                ),
1733                vec![(9..11, "".to_string())]
1734            );
1735
1736            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1737            assert_eq!(
1738                from_completion_edits(
1739                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1740                    &buffer,
1741                    cx
1742                ),
1743                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1744            );
1745
1746            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1747            assert_eq!(
1748                from_completion_edits(
1749                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1750                    &buffer,
1751                    cx
1752                ),
1753                vec![(4..4, "M".to_string())]
1754            );
1755
1756            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1757            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1758        })
1759    }
1760
1761    #[gpui::test]
1762    async fn test_clean_up_diff(cx: &mut TestAppContext) {
1763        init_test(cx);
1764
1765        assert_eq!(
1766            apply_edit_prediction(
1767                indoc! {"
1768                    fn main() {
1769                        let word_1 = \"lorem\";
1770                        let range = word.len()..word.len();
1771                    }
1772                "},
1773                indoc! {"
1774                    <|editable_region_start|>
1775                    fn main() {
1776                        let word_1 = \"lorem\";
1777                        let range = word_1.len()..word_1.len();
1778                    }
1779
1780                    <|editable_region_end|>
1781                "},
1782                cx,
1783            )
1784            .await,
1785            indoc! {"
1786                fn main() {
1787                    let word_1 = \"lorem\";
1788                    let range = word_1.len()..word_1.len();
1789                }
1790            "},
1791        );
1792
1793        assert_eq!(
1794            apply_edit_prediction(
1795                indoc! {"
1796                    fn main() {
1797                        let story = \"the quick\"
1798                    }
1799                "},
1800                indoc! {"
1801                    <|editable_region_start|>
1802                    fn main() {
1803                        let story = \"the quick brown fox jumps over the lazy dog\";
1804                    }
1805
1806                    <|editable_region_end|>
1807                "},
1808                cx,
1809            )
1810            .await,
1811            indoc! {"
1812                fn main() {
1813                    let story = \"the quick brown fox jumps over the lazy dog\";
1814                }
1815            "},
1816        );
1817    }
1818
1819    #[gpui::test]
1820    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1821        init_test(cx);
1822
1823        let buffer_content = "lorem\n";
1824        let completion_response = indoc! {"
1825            ```animals.js
1826            <|start_of_file|>
1827            <|editable_region_start|>
1828            lorem
1829            ipsum
1830            <|editable_region_end|>
1831            ```"};
1832
1833        assert_eq!(
1834            apply_edit_prediction(buffer_content, completion_response, cx).await,
1835            "lorem\nipsum"
1836        );
1837    }
1838
1839    #[gpui::test]
1840    async fn test_can_collect_data(cx: &mut TestAppContext) {
1841        init_test(cx);
1842
1843        let fs = project::FakeFs::new(cx.executor());
1844        fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1845            .await;
1846
1847        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1848        let buffer = project
1849            .update(cx, |project, cx| {
1850                project.open_local_buffer(path!("/project/src/main.rs"), cx)
1851            })
1852            .await
1853            .unwrap();
1854
1855        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1856        zeta.update(cx, |zeta, _cx| {
1857            zeta.data_collection_choice = DataCollectionChoice::Enabled
1858        });
1859
1860        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1861        assert_eq!(
1862            captured_request.lock().clone().unwrap().can_collect_data,
1863            true
1864        );
1865
1866        zeta.update(cx, |zeta, _cx| {
1867            zeta.data_collection_choice = DataCollectionChoice::Disabled
1868        });
1869
1870        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1871        assert_eq!(
1872            captured_request.lock().clone().unwrap().can_collect_data,
1873            false
1874        );
1875    }
1876
1877    #[gpui::test]
1878    async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1879        init_test(cx);
1880
1881        let fs = project::FakeFs::new(cx.executor());
1882        let project = Project::test(fs.clone(), [], cx).await;
1883
1884        let buffer = cx.new(|_cx| {
1885            Buffer::remote(
1886                language::BufferId::new(1).unwrap(),
1887                1,
1888                language::Capability::ReadWrite,
1889                "fn main() {\n    println!(\"Hello\");\n}",
1890            )
1891        });
1892
1893        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1894        zeta.update(cx, |zeta, _cx| {
1895            zeta.data_collection_choice = DataCollectionChoice::Enabled
1896        });
1897
1898        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1899        assert_eq!(
1900            captured_request.lock().clone().unwrap().can_collect_data,
1901            false
1902        );
1903    }
1904
1905    #[gpui::test]
1906    async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1907        init_test(cx);
1908
1909        let fs = project::FakeFs::new(cx.executor());
1910        fs.insert_tree(
1911            path!("/project"),
1912            json!({
1913                "LICENSE": BSD_0_TXT,
1914                ".env": "SECRET_KEY=secret"
1915            }),
1916        )
1917        .await;
1918
1919        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1920        let buffer = project
1921            .update(cx, |project, cx| {
1922                project.open_local_buffer("/project/.env", cx)
1923            })
1924            .await
1925            .unwrap();
1926
1927        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1928        zeta.update(cx, |zeta, _cx| {
1929            zeta.data_collection_choice = DataCollectionChoice::Enabled
1930        });
1931
1932        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1933        assert_eq!(
1934            captured_request.lock().clone().unwrap().can_collect_data,
1935            false
1936        );
1937    }
1938
1939    #[gpui::test]
1940    async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1941        init_test(cx);
1942
1943        let fs = project::FakeFs::new(cx.executor());
1944        let project = Project::test(fs.clone(), [], cx).await;
1945        let buffer = cx.new(|cx| Buffer::local("", cx));
1946
1947        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1948        zeta.update(cx, |zeta, _cx| {
1949            zeta.data_collection_choice = DataCollectionChoice::Enabled
1950        });
1951
1952        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1953        assert_eq!(
1954            captured_request.lock().clone().unwrap().can_collect_data,
1955            false
1956        );
1957    }
1958
1959    #[gpui::test]
1960    async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1961        init_test(cx);
1962
1963        let fs = project::FakeFs::new(cx.executor());
1964        fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1965            .await;
1966
1967        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1968        let buffer = project
1969            .update(cx, |project, cx| {
1970                project.open_local_buffer("/project/main.rs", cx)
1971            })
1972            .await
1973            .unwrap();
1974
1975        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
1976        zeta.update(cx, |zeta, _cx| {
1977            zeta.data_collection_choice = DataCollectionChoice::Enabled
1978        });
1979
1980        run_edit_prediction(&buffer, &project, &zeta, cx).await;
1981        assert_eq!(
1982            captured_request.lock().clone().unwrap().can_collect_data,
1983            false
1984        );
1985    }
1986
1987    #[gpui::test]
1988    async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1989        init_test(cx);
1990
1991        let fs = project::FakeFs::new(cx.executor());
1992        fs.insert_tree(
1993            path!("/open_source_worktree"),
1994            json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1995        )
1996        .await;
1997        fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1998            .await;
1999
2000        let project = Project::test(
2001            fs.clone(),
2002            [
2003                path!("/open_source_worktree").as_ref(),
2004                path!("/closed_source_worktree").as_ref(),
2005            ],
2006            cx,
2007        )
2008        .await;
2009        let buffer = project
2010            .update(cx, |project, cx| {
2011                project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
2012            })
2013            .await
2014            .unwrap();
2015
2016        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2017        zeta.update(cx, |zeta, _cx| {
2018            zeta.data_collection_choice = DataCollectionChoice::Enabled
2019        });
2020
2021        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2022        assert_eq!(
2023            captured_request.lock().clone().unwrap().can_collect_data,
2024            true
2025        );
2026
2027        let closed_source_file = project
2028            .update(cx, |project, cx| {
2029                let worktree2 = project
2030                    .worktree_for_root_name("closed_source_worktree", cx)
2031                    .unwrap();
2032                worktree2.update(cx, |worktree2, cx| {
2033                    worktree2.load_file(rel_path("main.rs"), cx)
2034                })
2035            })
2036            .await
2037            .unwrap()
2038            .file;
2039
2040        buffer.update(cx, |buffer, cx| {
2041            buffer.file_updated(closed_source_file, cx);
2042        });
2043
2044        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2045        assert_eq!(
2046            captured_request.lock().clone().unwrap().can_collect_data,
2047            false
2048        );
2049    }
2050
2051    #[gpui::test]
2052    async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
2053        init_test(cx);
2054
2055        let fs = project::FakeFs::new(cx.executor());
2056        fs.insert_tree(
2057            path!("/worktree1"),
2058            json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
2059        )
2060        .await;
2061        fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
2062            .await;
2063
2064        let project = Project::test(
2065            fs.clone(),
2066            [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
2067            cx,
2068        )
2069        .await;
2070        let buffer = project
2071            .update(cx, |project, cx| {
2072                project.open_local_buffer(path!("/worktree1/main.rs"), cx)
2073            })
2074            .await
2075            .unwrap();
2076        let private_buffer = project
2077            .update(cx, |project, cx| {
2078                project.open_local_buffer(path!("/worktree2/file.rs"), cx)
2079            })
2080            .await
2081            .unwrap();
2082
2083        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
2084        zeta.update(cx, |zeta, _cx| {
2085            zeta.data_collection_choice = DataCollectionChoice::Enabled
2086        });
2087
2088        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2089        assert_eq!(
2090            captured_request.lock().clone().unwrap().can_collect_data,
2091            true
2092        );
2093
2094        // this has a side effect of registering the buffer to watch for edits
2095        run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
2096        assert_eq!(
2097            captured_request.lock().clone().unwrap().can_collect_data,
2098            false
2099        );
2100
2101        private_buffer.update(cx, |private_buffer, cx| {
2102            private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
2103        });
2104
2105        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2106        assert_eq!(
2107            captured_request.lock().clone().unwrap().can_collect_data,
2108            false
2109        );
2110
2111        // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
2112        // included
2113        buffer.update(cx, |buffer, cx| {
2114            buffer.edit(
2115                [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
2116                None,
2117                cx,
2118            );
2119        });
2120
2121        run_edit_prediction(&buffer, &project, &zeta, cx).await;
2122        assert_eq!(
2123            captured_request.lock().clone().unwrap().can_collect_data,
2124            true
2125        );
2126    }
2127
2128    fn init_test(cx: &mut TestAppContext) {
2129        cx.update(|cx| {
2130            let settings_store = SettingsStore::test(cx);
2131            cx.set_global(settings_store);
2132            language::init(cx);
2133            client::init_settings(cx);
2134            Project::init_settings(cx);
2135        });
2136    }
2137
2138    async fn apply_edit_prediction(
2139        buffer_content: &str,
2140        completion_response: &str,
2141        cx: &mut TestAppContext,
2142    ) -> String {
2143        let fs = project::FakeFs::new(cx.executor());
2144        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2145        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2146        let (zeta, _, response) = make_test_zeta(&project, cx).await;
2147        *response.lock() = completion_response.to_string();
2148        let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
2149        buffer.update(cx, |buffer, cx| {
2150            buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2151        });
2152        buffer.read_with(cx, |buffer, _| buffer.text())
2153    }
2154
2155    async fn run_edit_prediction(
2156        buffer: &Entity<Buffer>,
2157        project: &Entity<Project>,
2158        zeta: &Entity<Zeta>,
2159        cx: &mut TestAppContext,
2160    ) -> EditPrediction {
2161        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2162        zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
2163        cx.background_executor.run_until_parked();
2164        let completion_task = zeta.update(cx, |zeta, cx| {
2165            zeta.request_completion(&project, buffer, cursor, cx)
2166        });
2167        completion_task.await.unwrap().unwrap()
2168    }
2169
    /// Builds a `Zeta` for `project` backed by a fake HTTP client that:
    /// - answers `POST /client/llm_tokens` with a fixed LLM token,
    /// - answers `POST /predict_edits/v2` by recording the deserialized request body and
    ///   replying with the current contents of the response slot as `output_excerpt`,
    /// - answers every other request with `404 Not Found`.
    ///
    /// A license detection watcher is registered for every worktree already in `project`.
    ///
    /// Returns `(zeta, captured_request, completion_response)`. `captured_request` holds the
    /// most recent `PredictEditsBody` sent (`None` until the first prediction request).
    /// Overwriting `completion_response` changes the model output returned by later requests.
    async fn make_test_zeta(
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        Arc<Mutex<Option<PredictEditsBody>>>,
        Arc<Mutex<String>>,
    ) {
        // Canned model output used until a test overwrites the response slot.
        let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
        };
        let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
        let completion_response: Arc<Mutex<String>> =
            Arc::new(Mutex::new(default_response.to_string()));
        let http_client = FakeHttpClient::create({
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            move |req| {
                // Each request gets its own handles, because the `async move` future below
                // takes ownership of whatever it captures.
                let captured_request = captured_request.clone();
                let completion_response = completion_response.clone();
                async move {
                    match (req.method(), req.uri().path()) {
                        (&Method::POST, "/client/llm_tokens") => {
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&CreateLlmTokenResponse {
                                        token: LlmToken("the-llm-token".to_string()),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        (&Method::POST, "/predict_edits/v2") => {
                            // Record the request so tests can inspect fields such as
                            // `can_collect_data`.
                            let mut request_body = String::new();
                            req.into_body().read_to_string(&mut request_body).await?;
                            *captured_request.lock() =
                                Some(serde_json::from_str(&request_body).unwrap());
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&PredictEditsResponse {
                                        request_id: Uuid::new_v4(),
                                        output_excerpt: completion_response.lock().clone(),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        _ => Ok(http_client::Response::builder()
                            .status(404)
                            .body("Not Found".into())
                            .unwrap()),
                    }
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Binding to `_server` (rather than `_`) keeps the fake server alive until this
        // function returns.
        let _server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new(|cx| {
            let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);

            // Collect first so the read borrow of `project` ends before `cx` is used again
            // inside the loop (presumably why this isn't iterated directly).
            let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
            for worktree in worktrees {
                let worktree_id = worktree.read(cx).id();
                zeta.license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
            }

            zeta
        });

        (zeta, captured_request, completion_response)
    }
2257
2258    fn to_completion_edits(
2259        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2260        buffer: &Entity<Buffer>,
2261        cx: &App,
2262    ) -> Vec<(Range<Anchor>, String)> {
2263        let buffer = buffer.read(cx);
2264        iterator
2265            .into_iter()
2266            .map(|(range, text)| {
2267                (
2268                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2269                    text,
2270                )
2271            })
2272            .collect()
2273    }
2274
2275    fn from_completion_edits(
2276        editor_edits: &[(Range<Anchor>, String)],
2277        buffer: &Entity<Buffer>,
2278        cx: &App,
2279    ) -> Vec<(Range<usize>, String)> {
2280        let buffer = buffer.read(cx);
2281        editor_edits
2282            .iter()
2283            .map(|(range, text)| {
2284                (
2285                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
2286                    text.clone(),
2287                )
2288            })
2289            .collect()
2290    }
2291
    /// Initializes `zlog` test logging once per test binary. The `ctor` attribute runs this at
    /// process startup, before any test executes.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2296}