zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::{Dismissable, KEY_VALUE_STORE};
  11use db::smol::stream::StreamExt as _;
  12use edit_prediction::DataCollectionState;
  13use futures::channel::mpsc;
  14pub use init::*;
  15use license_detection::LicenseDetectionWatcher;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use arrayvec::ArrayVec;
  20use client::{Client, EditPredictionUsage, UserStore};
  21use cloud_llm_client::{
  22    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
  23    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  24    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
  25    ZED_VERSION_HEADER_NAME,
  26};
  27use collections::{HashMap, HashSet, VecDeque};
  28use futures::AsyncReadExt;
  29use gpui::{
  30    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  31    SharedString, Subscription, Task, actions,
  32};
  33use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  34use input_excerpt::excerpt_for_cursor_position;
  35use language::{
  36    Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  37};
  38use language_model::{LlmApiToken, RefreshLlmTokenListener};
  39use project::{Project, ProjectPath};
  40use release_channel::AppVersion;
  41use settings::WorktreeId;
  42use std::collections::hash_map;
  43use std::mem;
  44use std::str::FromStr;
  45use std::{
  46    cmp,
  47    fmt::Write,
  48    future::Future,
  49    ops::Range,
  50    path::Path,
  51    rc::Rc,
  52    sync::Arc,
  53    time::{Duration, Instant},
  54};
  55use telemetry_events::EditPredictionRating;
  56use thiserror::Error;
  57use util::ResultExt;
  58use util::rel_path::RelPath;
  59use uuid::Uuid;
  60use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  61use worktree::Worktree;
  62
// Marker strings. They are not used in this part of the file; presumably they
// delimit the cursor position, the start of the file and the editable region in
// the excerpt exchanged with the model — TODO confirm against their use sites.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits to the same buffer that arrive within this interval are
/// coalesced into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value-store key holding the user's data collection choice. Its presence
/// also counts as having dismissed the upsell (see `ZedPredictUpsell::dismissed`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets for the prompt. `MAX_EVENT_TOKENS` is the budget handed to
// `prompt_for_events_impl` for the edit history; the other two are not used in
// this part of the file (presumably the surrounding-context and editable-region
// budgets — verify).
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
///
/// When the history reaches this size, its oldest half is dropped in one step
/// (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;

actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  84
  85#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  86pub struct EditPredictionId(Uuid);
  87
  88impl From<EditPredictionId> for gpui::ElementId {
  89    fn from(value: EditPredictionId) -> Self {
  90        gpui::ElementId::Uuid(value.0)
  91    }
  92}
  93
  94impl std::fmt::Display for EditPredictionId {
  95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  96        write!(f, "{}", self.0)
  97    }
  98}
  99
 100struct ZedPredictUpsell;
 101
 102impl Dismissable for ZedPredictUpsell {
 103    const KEY: &'static str = "dismissed-edit-predict-upsell";
 104
 105    fn dismissed() -> bool {
 106        // To make this backwards compatible with older versions of Zed, we
 107        // check if the user has seen the previous Edit Prediction Onboarding
 108        // before, by checking the data collection choice which was written to
 109        // the database once the user clicked on "Accept and Enable"
 110        if KEY_VALUE_STORE
 111            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 112            .log_err()
 113            .is_some_and(|s| s.is_some())
 114        {
 115            return true;
 116        }
 117
 118        KEY_VALUE_STORE
 119            .read_kvp(Self::KEY)
 120            .log_err()
 121            .is_some_and(|s| s.is_some())
 122    }
 123}
 124
 125pub fn should_show_upsell_modal() -> bool {
 126    !ZedPredictUpsell::dismissed()
 127}
 128
/// GPUI global holding the single shared `Zeta` entity.
///
/// It is installed by `Zeta::register` on first use and read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 133
/// An edit prediction for one buffer, together with the request inputs and
/// model output it was derived from and its timing.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    path: Arc<Path>,
    excerpt_range: Range<usize>,
    cursor_offset: usize,
    /// Predicted replacements, expressed against `snapshot`.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Buffer state that `edits` refer to. `interpolate` carries the edits over
    /// to newer snapshots.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
    // NOTE(review): the outline/events/excerpt that were sent in the request,
    // plus the raw model output. The constructor is outside this chunk;
    // presumably they are kept for rating/telemetry — confirm.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// Start of the interval reported by `latency`.
    buffer_snapshotted_at: Instant,
    /// End of the interval reported by `latency`.
    response_received_at: Instant,
}
 150
 151impl EditPrediction {
 152    fn latency(&self) -> Duration {
 153        self.response_received_at
 154            .duration_since(self.buffer_snapshotted_at)
 155    }
 156
 157    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 158        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 159    }
 160}
 161
 162impl std::fmt::Debug for EditPrediction {
 163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 164        f.debug_struct("EditPrediction")
 165            .field("id", &self.id)
 166            .field("path", &self.path)
 167            .field("edits", &self.edits)
 168            .finish_non_exhaustive()
 169    }
 170}
 171
/// Edit-prediction engine. It tracks per-project edit history, requests
/// predictions from the Zed LLM service, and reports accepted and rejected
/// predictions back to it.
pub struct Zeta {
    /// Per-project state, keyed by the project entity's id. An entry is removed
    /// when its project entity is released.
    projects: HashMap<EntityId, ZetaProject>,
    client: Arc<Client>,
    // NOTE(review): not touched in this part of the file; presumably recently
    // shown predictions used by the rating UI — confirm.
    shown_completions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: DataCollectionChoice,
    /// Rejections not yet acknowledged by the server. `reject_edit_predictions`
    /// sends them and removes the acknowledged prefix on success.
    discarded_completions: Vec<EditPredictionRejection>,
    /// LLM API token. It is refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license watcher per worktree passed to `Zeta::register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // NOTE(review): initialised to `None` in `new`; presumably holds a pending
    // debounce before signalling `discard_completions_tx` — confirm.
    discard_completions_debounce_task: Option<Task<()>>,
    /// Each `()` sent here makes the task spawned in `Zeta::new` call
    /// `reject_edit_predictions`.
    discard_completions_tx: mpsc::UnboundedSender<()>,
}

/// Per-project edit history and buffer registrations.
struct ZetaProject {
    /// Recent buffer changes, oldest first, holding at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers observed for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
 193
 194impl Zeta {
 195    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 196        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 197    }
 198
 199    pub fn register(
 200        worktree: Option<Entity<Worktree>>,
 201        client: Arc<Client>,
 202        user_store: Entity<UserStore>,
 203        cx: &mut App,
 204    ) -> Entity<Self> {
 205        let this = Self::global(cx).unwrap_or_else(|| {
 206            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 207            cx.set_global(ZetaGlobal(entity.clone()));
 208            entity
 209        });
 210
 211        this.update(cx, move |this, cx| {
 212            if let Some(worktree) = worktree {
 213                let worktree_id = worktree.read(cx).id();
 214                this.license_detection_watchers
 215                    .entry(worktree_id)
 216                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 217            }
 218        });
 219
 220        this
 221    }
 222
 223    pub fn clear_history(&mut self) {
 224        for zeta_project in self.projects.values_mut() {
 225            zeta_project.events.clear();
 226        }
 227    }
 228
 229    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 230        self.user_store.read(cx).edit_prediction_usage()
 231    }
 232
 233    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 234        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 235        let data_collection_choice = Self::load_data_collection_choice();
 236        let (reject_tx, mut reject_rx) = mpsc::unbounded();
 237        cx.spawn(async move |this, cx| {
 238            while let Some(()) = reject_rx.next().await {
 239                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
 240                    .await
 241                    .log_err();
 242            }
 243            anyhow::Ok(())
 244        })
 245        .detach();
 246
 247        Self {
 248            projects: HashMap::default(),
 249            client,
 250            shown_completions: VecDeque::new(),
 251            rated_completions: HashSet::default(),
 252            discarded_completions: Vec::new(),
 253            discard_completions_debounce_task: None,
 254            discard_completions_tx: reject_tx,
 255            data_collection_choice,
 256            llm_token: LlmApiToken::default(),
 257            _llm_token_subscription: cx.subscribe(
 258                &refresh_llm_token_listener,
 259                |this, _listener, _event, cx| {
 260                    let client = this.client.clone();
 261                    let llm_token = this.llm_token.clone();
 262                    cx.spawn(async move |_this, _cx| {
 263                        llm_token.refresh(&client).await?;
 264                        anyhow::Ok(())
 265                    })
 266                    .detach_and_log_err(cx);
 267                },
 268            ),
 269            update_required: false,
 270            license_detection_watchers: HashMap::default(),
 271            user_store,
 272        }
 273    }
 274
 275    fn get_or_init_zeta_project(
 276        &mut self,
 277        project: &Entity<Project>,
 278        cx: &mut Context<Self>,
 279    ) -> &mut ZetaProject {
 280        let project_id = project.entity_id();
 281        match self.projects.entry(project_id) {
 282            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 283            hash_map::Entry::Vacant(entry) => {
 284                cx.observe_release(project, move |this, _, _cx| {
 285                    this.projects.remove(&project_id);
 286                })
 287                .detach();
 288                entry.insert(ZetaProject {
 289                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 290                    registered_buffers: HashMap::default(),
 291                })
 292            }
 293        }
 294    }
 295
 296    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 297        let events = &mut zeta_project.events;
 298
 299        if let Some(Event::BufferChange {
 300            new_snapshot: last_new_snapshot,
 301            timestamp: last_timestamp,
 302            ..
 303        }) = events.back_mut()
 304        {
 305            // Coalesce edits for the same buffer when they happen one after the other.
 306            let Event::BufferChange {
 307                old_snapshot,
 308                new_snapshot,
 309                timestamp,
 310            } = &event;
 311
 312            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 313                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 314                && old_snapshot.version == last_new_snapshot.version
 315            {
 316                *last_new_snapshot = new_snapshot.clone();
 317                *last_timestamp = *timestamp;
 318                return;
 319            }
 320        }
 321
 322        if events.len() >= MAX_EVENT_COUNT {
 323            // These are halved instead of popping to improve prompt caching.
 324            events.drain(..MAX_EVENT_COUNT / 2);
 325        }
 326
 327        events.push_back(event);
 328    }
 329
 330    pub fn register_buffer(
 331        &mut self,
 332        buffer: &Entity<Buffer>,
 333        project: &Entity<Project>,
 334        cx: &mut Context<Self>,
 335    ) {
 336        let zeta_project = self.get_or_init_zeta_project(project, cx);
 337        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 338    }
 339
 340    fn register_buffer_impl<'a>(
 341        zeta_project: &'a mut ZetaProject,
 342        buffer: &Entity<Buffer>,
 343        project: &Entity<Project>,
 344        cx: &mut Context<Self>,
 345    ) -> &'a mut RegisteredBuffer {
 346        let buffer_id = buffer.entity_id();
 347        match zeta_project.registered_buffers.entry(buffer_id) {
 348            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 349            hash_map::Entry::Vacant(entry) => {
 350                let snapshot = buffer.read(cx).snapshot();
 351                let project_entity_id = project.entity_id();
 352                entry.insert(RegisteredBuffer {
 353                    snapshot,
 354                    _subscriptions: [
 355                        cx.subscribe(buffer, {
 356                            let project = project.downgrade();
 357                            move |this, buffer, event, cx| {
 358                                if let language::BufferEvent::Edited = event
 359                                    && let Some(project) = project.upgrade()
 360                                {
 361                                    this.report_changes_for_buffer(&buffer, &project, cx);
 362                                }
 363                            }
 364                        }),
 365                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 366                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 367                            else {
 368                                return;
 369                            };
 370                            zeta_project.registered_buffers.remove(&buffer_id);
 371                        }),
 372                    ],
 373                })
 374            }
 375        }
 376    }
 377
    /// Shared implementation behind `request_completion` and `fake_completion`.
    ///
    /// Work done synchronously, before the task is returned:
    /// * reports pending changes for `buffer` and takes its snapshot;
    /// * copies the project's event history;
    /// * decides whether the file may be collected and, if so, looks up its git info;
    /// * starts gathering the prompt context around `cursor`.
    ///
    /// The returned task then:
    /// * sends the request through `perform_predict_edits`, which lets tests
    ///   inject a canned response;
    /// * on a `ZedUpdateRequiredError`, sets `update_required` and shows an
    ///   "Update Zed" notification before returning the error;
    /// * stores any usage information in the user store;
    /// * turns the response into an `EditPrediction` via
    ///   `process_completion_response`.
    fn request_completion_impl<F, R>(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency window (telemetry below and `buffer_snapshotted_at`).
        let buffer_snapshotted_at = Instant::now();
        // NOTE(review): presumably records the buffer's latest edits as an event
        // and returns its current snapshot — `report_changes_for_buffer` is not
        // in this chunk; confirm.
        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
        let zeta = cx.entity();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Snapshot the history now, so edits made while the request is in flight
        // do not change what is sent or checked for data collection.
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let mut events = Vec::with_capacity(zeta_project.events.len());
        events.extend(zeta_project.events.iter().cloned());
        let events = Arc::new(events);

        // Git info is only looked up for files whose data may be collected.
        let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
            let can_collect_file = self.can_collect_file(file, cx);
            let git_info = if can_collect_file {
                git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
            } else {
                None
            };
            (git_info, can_collect_file)
        } else {
            (None, false)
        };

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().into_owned();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Renders the history into prompt text within `MAX_EVENT_TOKENS`; called by
        // the context gatherer.
        let prompt_for_events = {
            let events = events.clone();
            move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
        };
        let gather_task = gather_context(
            full_path_str,
            &snapshot,
            cursor_point,
            prompt_for_events,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                mut body,
                editable_range,
                included_events_count,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            // Only the newest `included_events_count` events made it into the prompt,
            // so only those take part in the data-collection check. Assumes the
            // gatherer never reports more events than it was given; the slice would
            // panic otherwise.
            let included_events = &events[events.len() - included_events_count..events.len()];
            body.can_collect_data = can_collect_file
                && this
                    .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                    .unwrap_or(false);
            if body.can_collect_data {
                body.git_info = git_info;
            }

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the prompt inputs; `body` is moved into the request.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // A too-old client is flagged and surfaced to the user once per
                    // notification id; the error is still returned to the caller.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // Record latency for ~1% of requests (3 of the 256 `u8` values, ≈1.2%).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 547
 548    #[cfg(any(test, feature = "test-support"))]
 549    pub fn fake_completion(
 550        &mut self,
 551        project: &Entity<Project>,
 552        buffer: &Entity<Buffer>,
 553        position: language::Anchor,
 554        response: PredictEditsResponse,
 555        cx: &mut Context<Self>,
 556    ) -> Task<Result<Option<EditPrediction>>> {
 557        self.request_completion_impl(project, buffer, position, cx, |_params| {
 558            std::future::ready(Ok((response, None)))
 559        })
 560    }
 561
 562    pub fn request_completion(
 563        &mut self,
 564        project: &Entity<Project>,
 565        buffer: &Entity<Buffer>,
 566        position: language::Anchor,
 567        cx: &mut Context<Self>,
 568    ) -> Task<Result<Option<EditPrediction>>> {
 569        self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
 570    }
 571
    /// Sends `params.body` to the prediction endpoint. Returns the parsed
    /// response together with the usage information from the response headers,
    /// if those headers can be parsed.
    ///
    /// The endpoint is `ZED_PREDICT_EDITS_URL` when that environment variable is
    /// set, and the Zed LLM `/predict_edits/v2` URL otherwise.
    ///
    /// # Errors
    ///
    /// * `ZedUpdateRequiredError` if the response advertises a minimum version
    ///   newer than `app_version`. This is checked before the status code.
    /// * An error carrying the status and body for a non-success response. When
    ///   the server reports an expired token, the request is retried once with a
    ///   refreshed token first.
    /// * Token acquisition, URL construction, transport and JSON (de)serialization errors.
    pub fn perform_predict_edits(
        params: PerformPredictEditsParams,
    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
        async move {
            let PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
                ..
            } = params;

            let http_client = client.http_client();
            let mut token = llm_token.acquire(&client).await?;
            // Guards against retrying more than once on an expired token.
            let mut did_retry = false;

            loop {
                // Rebuilt on every iteration so a retry carries the refreshed token.
                let request_builder = http_client::Request::builder().method(Method::POST);
                let request_builder =
                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                        request_builder.uri(predict_edits_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/v2", &[])?
                                .as_ref(),
                        )
                    };
                let request = request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(serde_json::to_string(&body)?.into())?;

                let mut response = http_client.send(request).await?;

                // A minimum-version header that cannot be parsed is ignored.
                if let Some(minimum_required_version) = response
                    .headers()
                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                {
                    anyhow::ensure!(
                        app_version >= minimum_required_version,
                        ZedUpdateRequiredError {
                            minimum_version: minimum_required_version
                        }
                    );
                }

                if response.status().is_success() {
                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Ok((serde_json::from_str(&body)?, usage));
                } else if !did_retry
                    && response
                        .headers()
                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                        .is_some()
                {
                    did_retry = true;
                    token = llm_token.refresh(&client).await?;
                } else {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "error predicting edits.\nStatus: {:?}\nBody: {}",
                        response.status(),
                        body
                    );
                }
            }
        }
    }
 647
 648    fn accept_edit_prediction(
 649        &mut self,
 650        request_id: EditPredictionId,
 651        cx: &mut Context<Self>,
 652    ) -> Task<Result<()>> {
 653        let client = self.client.clone();
 654        let llm_token = self.llm_token.clone();
 655        let app_version = AppVersion::global(cx);
 656        cx.spawn(async move |this, cx| {
 657            let http_client = client.http_client();
 658            let mut response = llm_token_retry(&llm_token, &client, |token| {
 659                let request_builder = http_client::Request::builder().method(Method::POST);
 660                let request_builder =
 661                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 662                        request_builder.uri(accept_prediction_url)
 663                    } else {
 664                        request_builder.uri(
 665                            http_client
 666                                .build_zed_llm_url("/predict_edits/accept", &[])?
 667                                .as_ref(),
 668                        )
 669                    };
 670                Ok(request_builder
 671                    .header("Content-Type", "application/json")
 672                    .header("Authorization", format!("Bearer {}", token))
 673                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 674                    .body(
 675                        serde_json::to_string(&AcceptEditPredictionBody {
 676                            request_id: request_id.0.to_string(),
 677                        })?
 678                        .into(),
 679                    )?)
 680            })
 681            .await?;
 682
 683            if let Some(minimum_required_version) = response
 684                .headers()
 685                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 686                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 687                && app_version < minimum_required_version
 688            {
 689                return Err(anyhow!(ZedUpdateRequiredError {
 690                    minimum_version: minimum_required_version
 691                }));
 692            }
 693
 694            if response.status().is_success() {
 695                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 696                    this.update(cx, |this, cx| {
 697                        this.user_store.update(cx, |user_store, cx| {
 698                            user_store.update_edit_prediction_usage(usage, cx);
 699                        });
 700                    })?;
 701                }
 702
 703                Ok(())
 704            } else {
 705                let mut body = String::new();
 706                response.body_mut().read_to_string(&mut body).await?;
 707                Err(anyhow!(
 708                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 709                    response.status(),
 710                    body
 711                ))
 712            }
 713        })
 714    }
 715
 716    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 717        let client = self.client.clone();
 718        let llm_token = self.llm_token.clone();
 719        let app_version = AppVersion::global(cx);
 720        let last_rejection = self.discarded_completions.last().cloned();
 721        let body = serde_json::to_string(&RejectEditPredictionsBody {
 722            rejections: self.discarded_completions.clone(),
 723        })
 724        .ok();
 725
 726        let Some(last_rejection) = last_rejection else {
 727            return Task::ready(anyhow::Ok(()));
 728        };
 729
 730        cx.spawn(async move |this, cx| {
 731            let http_client = client.http_client();
 732            let mut response = llm_token_retry(&llm_token, &client, |token| {
 733                let request_builder = http_client::Request::builder().method(Method::POST);
 734                let request_builder = request_builder.uri(
 735                    http_client
 736                        .build_zed_llm_url("/predict_edits/reject", &[])?
 737                        .as_ref(),
 738                );
 739                Ok(request_builder
 740                    .header("Content-Type", "application/json")
 741                    .header("Authorization", format!("Bearer {}", token))
 742                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 743                    .body(
 744                        body.as_ref()
 745                            .context("failed to serialize body")?
 746                            .clone()
 747                            .into(),
 748                    )?)
 749            })
 750            .await?;
 751
 752            if let Some(minimum_required_version) = response
 753                .headers()
 754                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 755                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 756                && app_version < minimum_required_version
 757            {
 758                return Err(anyhow!(ZedUpdateRequiredError {
 759                    minimum_version: minimum_required_version
 760                }));
 761            }
 762
 763            if response.status().is_success() {
 764                this.update(cx, |this, _| {
 765                    if let Some(ix) = this
 766                        .discarded_completions
 767                        .iter()
 768                        .position(|rejection| rejection.request_id == last_rejection.request_id)
 769                    {
 770                        this.discarded_completions.drain(..ix + 1);
 771                    }
 772                })
 773            } else {
 774                let mut body = String::new();
 775                response.body_mut().read_to_string(&mut body).await?;
 776                Err(anyhow!(
 777                    "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
 778                    response.status(),
 779                    body
 780                ))
 781            }
 782        })
 783    }
 784
    /// Turns a raw predict-edits response into an [`EditPrediction`] for `buffer`.
    ///
    /// The model's `output_excerpt` is parsed into edits on a background task, against
    /// `snapshot` (the buffer state the request was built from). Those edits are then
    /// interpolated onto the buffer's current snapshot. If they no longer apply, `Ok(None)` is
    /// returned. Otherwise an edit preview is computed, and the prediction is returned together
    /// with the prompt inputs (`input_outline`, `input_events`, `input_excerpt`), the raw model
    /// output and timing information.
    ///
    /// # Errors
    /// Fails if the model output cannot be parsed (`parse_edits`), if `read_with` on the buffer
    /// fails, or if the response's request id is not a valid UUID.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        // Owned copies so the spawned task does not borrow from the caller.
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse the model output into anchored edits off the foreground thread.
            let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Rebase the edits onto the buffer's current contents, which may have changed while
            // the request was in flight. If `interpolate_edits` returns `None` (propagated by the
            // `?` inside the closure), the prediction no longer applies and `Ok(None)` is returned.
            // The snapshot stored in the prediction is the new one.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                move |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
                            .into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            // Parsed only after interpolation succeeded, so a stale prediction yields `Ok(None)`
            // even if its id is malformed.
            let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;

            let edit_preview = edit_preview.await;

            // `response_received_at` is taken after parsing and preview computation finished.
            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 849
 850    fn parse_edits(
 851        output_excerpt: Arc<str>,
 852        editable_range: Range<usize>,
 853        snapshot: &BufferSnapshot,
 854    ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
 855        let content = output_excerpt.replace(CURSOR_MARKER, "");
 856
 857        let start_markers = content
 858            .match_indices(EDITABLE_REGION_START_MARKER)
 859            .collect::<Vec<_>>();
 860        anyhow::ensure!(
 861            start_markers.len() == 1,
 862            "expected exactly one start marker, found {}",
 863            start_markers.len()
 864        );
 865
 866        let end_markers = content
 867            .match_indices(EDITABLE_REGION_END_MARKER)
 868            .collect::<Vec<_>>();
 869        anyhow::ensure!(
 870            end_markers.len() == 1,
 871            "expected exactly one end marker, found {}",
 872            end_markers.len()
 873        );
 874
 875        let sof_markers = content
 876            .match_indices(START_OF_FILE_MARKER)
 877            .collect::<Vec<_>>();
 878        anyhow::ensure!(
 879            sof_markers.len() <= 1,
 880            "expected at most one start-of-file marker, found {}",
 881            sof_markers.len()
 882        );
 883
 884        let codefence_start = start_markers[0].0;
 885        let content = &content[codefence_start..];
 886
 887        let newline_ix = content.find('\n').context("could not find newline")?;
 888        let content = &content[newline_ix + 1..];
 889
 890        let codefence_end = content
 891            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 892            .context("could not find end marker")?;
 893        let new_text = &content[..codefence_end];
 894
 895        let old_text = snapshot
 896            .text_for_range(editable_range.clone())
 897            .collect::<String>();
 898
 899        Ok(Self::compute_edits(
 900            old_text,
 901            new_text,
 902            editable_range.start,
 903            snapshot,
 904        ))
 905    }
 906
 907    pub fn compute_edits(
 908        old_text: String,
 909        new_text: &str,
 910        offset: usize,
 911        snapshot: &BufferSnapshot,
 912    ) -> Vec<(Range<Anchor>, Arc<str>)> {
 913        text_diff(&old_text, new_text)
 914            .into_iter()
 915            .map(|(mut old_range, new_text)| {
 916                old_range.start += offset;
 917                old_range.end += offset;
 918
 919                let prefix_len = common_prefix(
 920                    snapshot.chars_for_range(old_range.clone()),
 921                    new_text.chars(),
 922                );
 923                old_range.start += prefix_len;
 924
 925                let suffix_len = common_prefix(
 926                    snapshot.reversed_chars_for_range(old_range.clone()),
 927                    new_text[prefix_len..].chars().rev(),
 928                );
 929                old_range.end = old_range.end.saturating_sub(suffix_len);
 930
 931                let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
 932                let range = if old_range.is_empty() {
 933                    let anchor = snapshot.anchor_after(old_range.start);
 934                    anchor..anchor
 935                } else {
 936                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 937                };
 938                (range, new_text)
 939            })
 940            .collect()
 941    }
 942
 943    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
 944        self.rated_completions.contains(&completion_id)
 945    }
 946
 947    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
 948        self.shown_completions.push_front(completion.clone());
 949        if self.shown_completions.len() > 50 {
 950            let completion = self.shown_completions.pop_back().unwrap();
 951            self.rated_completions.remove(&completion.id);
 952        }
 953        cx.notify();
 954    }
 955
    /// Records the user's rating of `completion` and reports it via telemetry.
    ///
    /// Marks the completion as rated and emits an "Edit Prediction Rated" telemetry event. The
    /// event carries the rating, the prompt inputs (events, excerpt, outline), the model output
    /// and the free-form `feedback`. Telemetry is then flushed right away (the flush task is
    /// detached) and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush immediately rather than waiting for the next batch; the result is not awaited.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
 976
 977    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
 978        self.shown_completions.iter()
 979    }
 980
 981    pub fn shown_completions_len(&self) -> usize {
 982        self.shown_completions.len()
 983    }
 984
    /// Registers `buffer` with `project`'s Zeta state (if needed) and records any change since
    /// the last recorded snapshot.
    ///
    /// When the buffer's version differs from the stored snapshot, the stored snapshot is
    /// replaced, and an `Event::BufferChange` from the old to the new snapshot is passed to
    /// `push_event`. Returns the buffer's current snapshot in either case.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Only a version change counts as an edit worth reporting.
        if new_snapshot.version != registered_buffer.snapshot.version {
            // Swap in the new snapshot and keep the previous one for the event.
            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
1009
1010    fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1011        self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
1012    }
1013
1014    fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
1015        if !self.data_collection_choice.is_enabled() {
1016            return false;
1017        }
1018        let mut last_checked_file = None;
1019        for event in events {
1020            match event {
1021                Event::BufferChange {
1022                    old_snapshot,
1023                    new_snapshot,
1024                    ..
1025                } => {
1026                    if let Some(old_file) = old_snapshot.file()
1027                        && let Some(new_file) = new_snapshot.file()
1028                    {
1029                        if let Some(last_checked_file) = last_checked_file
1030                            && Arc::ptr_eq(last_checked_file, old_file)
1031                            && Arc::ptr_eq(last_checked_file, new_file)
1032                        {
1033                            continue;
1034                        }
1035                        if !self.can_collect_file(old_file, cx) {
1036                            return false;
1037                        }
1038                        if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
1039                        {
1040                            return false;
1041                        }
1042                        last_checked_file = Some(new_file);
1043                    } else {
1044                        return false;
1045                    }
1046                }
1047            }
1048        }
1049        true
1050    }
1051
1052    fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1053        if !file.is_local() || file.is_private() {
1054            return false;
1055        }
1056        self.license_detection_watchers
1057            .get(&file.worktree_id(cx))
1058            .is_some_and(|watcher| watcher.is_project_open_source())
1059    }
1060
1061    fn load_data_collection_choice() -> DataCollectionChoice {
1062        let choice = KEY_VALUE_STORE
1063            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1064            .log_err()
1065            .flatten();
1066
1067        match choice.as_deref() {
1068            Some("true") => DataCollectionChoice::Enabled,
1069            Some("false") => DataCollectionChoice::Disabled,
1070            Some(_) => {
1071                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1072                DataCollectionChoice::NotAnswered
1073            }
1074            None => DataCollectionChoice::NotAnswered,
1075        }
1076    }
1077
1078    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1079        self.data_collection_choice = self.data_collection_choice.toggle();
1080        let new_choice = self.data_collection_choice;
1081        db::write_and_log(cx, move || {
1082            KEY_VALUE_STORE.write_kvp(
1083                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1084                new_choice.is_enabled().to_string(),
1085            )
1086        });
1087    }
1088
1089    fn discard_completion(
1090        &mut self,
1091        completion_id: EditPredictionId,
1092        was_shown: bool,
1093        cx: &mut Context<Self>,
1094    ) {
1095        self.discarded_completions.push(EditPredictionRejection {
1096            request_id: completion_id.to_string(),
1097            was_shown,
1098        });
1099
1100        let reached_request_limit =
1101            self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
1102        let discard_completions_tx = self.discard_completions_tx.clone();
1103        self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
1104            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
1105            if !reached_request_limit {
1106                cx.background_executor()
1107                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
1108                    .await;
1109            }
1110            discard_completions_tx.unbounded_send(()).log_err();
1111        }));
1112    }
1113}
1114
/// Everything needed to send one predict-edits request to the server.
pub struct PerformPredictEditsParams {
    /// Client used to talk to the server.
    pub client: Arc<Client>,
    /// LLM API token used to authorize the request (presumably refreshed on expiry, as in
    /// `llm_token_retry` — confirm at the call site).
    pub llm_token: LlmApiToken,
    /// Version of the running app, sent alongside the request.
    pub app_version: SemanticVersion,
    /// The request payload.
    pub body: PredictEditsBody,
}
1121
/// Returned when a server response names a minimum required Zed version (via
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME`) that is newer than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest app version the server will accept.
    minimum_version: SemanticVersion,
}
1129
/// Returns the UTF-8 byte length of the longest common prefix of two character streams.
///
/// Comparison stops at the first mismatch or when either stream ends. Passing reversed
/// iterators measures the common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1136
1137fn git_info_for_file(
1138    project: &Entity<Project>,
1139    project_path: &ProjectPath,
1140    cx: &App,
1141) -> Option<PredictEditsGitInfo> {
1142    let git_store = project.read(cx).git_store().read(cx);
1143    if let Some((repository, _repo_path)) =
1144        git_store.repository_and_path_for_project_path(project_path, cx)
1145    {
1146        let repository = repository.read(cx);
1147        let head_sha = repository
1148            .head_commit
1149            .as_ref()
1150            .map(|head_commit| head_commit.sha.to_string());
1151        let remote_origin_url = repository.remote_origin_url.clone();
1152        let remote_upstream_url = repository.remote_upstream_url.clone();
1153        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1154            return None;
1155        }
1156        Some(PredictEditsGitInfo {
1157            head_sha,
1158            remote_origin_url,
1159            remote_upstream_url,
1160        })
1161    } else {
1162        None
1163    }
1164}
1165
/// Result of `gather_context`: the request body plus bookkeeping about what went into it.
pub struct GatherContextOutput {
    /// Request payload with the excerpt and recent-event prompts filled in.
    pub body: PredictEditsBody,
    /// Buffer offsets of the excerpt's editable region (the text the model may rewrite).
    pub editable_range: Range<usize>,
    /// Number of recent events included in `body.input_events`.
    pub included_events_count: usize,
}
1171
1172pub fn gather_context(
1173    full_path_str: String,
1174    snapshot: &BufferSnapshot,
1175    cursor_point: language::Point,
1176    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1177    cx: &App,
1178) -> Task<Result<GatherContextOutput>> {
1179    cx.background_spawn({
1180        let snapshot = snapshot.clone();
1181        async move {
1182            let input_excerpt = excerpt_for_cursor_position(
1183                cursor_point,
1184                &full_path_str,
1185                &snapshot,
1186                MAX_REWRITE_TOKENS,
1187                MAX_CONTEXT_TOKENS,
1188            );
1189            let (input_events, included_events_count) = prompt_for_events();
1190            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1191
1192            let body = PredictEditsBody {
1193                input_events,
1194                input_excerpt: input_excerpt.prompt,
1195                can_collect_data: false,
1196                diagnostic_groups: None,
1197                git_info: None,
1198                outline: None,
1199                speculated_output: None,
1200            };
1201
1202            Ok(GatherContextOutput {
1203                body,
1204                editable_range,
1205                included_events_count,
1206            })
1207        }
1208    })
1209}
1210
1211fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1212    let mut result = String::new();
1213    for (ix, event) in events.iter().rev().enumerate() {
1214        let event_string = event.to_prompt();
1215        let event_tokens = guess_token_count(event_string.len());
1216        if event_tokens > remaining_tokens {
1217            return (result, ix);
1218        }
1219
1220        if !result.is_empty() {
1221            result.insert_str(0, "\n\n");
1222        }
1223        result.insert_str(0, &event_string);
1224        remaining_tokens -= event_tokens;
1225    }
1226    return (result, events.len());
1227}
1228
/// Per-buffer state tracked for edit prediction.
struct RegisteredBuffer {
    /// Last recorded snapshot; compared against the buffer's current version to detect edits.
    snapshot: BufferSnapshot,
    /// Subscriptions held for as long as the buffer stays registered (what they observe is set
    /// up where the buffer is registered, outside this view).
    _subscriptions: [gpui::Subscription; 2],
}
1233
/// A recorded user action, rendered into the prompt to give the model recent context.
#[derive(Clone)]
pub enum Event {
    /// The buffer's contents (and possibly its file) changed between two snapshots.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1242
1243impl Event {
1244    fn to_prompt(&self) -> String {
1245        match self {
1246            Event::BufferChange {
1247                old_snapshot,
1248                new_snapshot,
1249                ..
1250            } => {
1251                let mut prompt = String::new();
1252
1253                let old_path = old_snapshot
1254                    .file()
1255                    .map(|f| f.path().as_ref())
1256                    .unwrap_or(RelPath::unix("untitled").unwrap());
1257                let new_path = new_snapshot
1258                    .file()
1259                    .map(|f| f.path().as_ref())
1260                    .unwrap_or(RelPath::unix("untitled").unwrap());
1261                if old_path != new_path {
1262                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1263                }
1264
1265                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1266                if !diff.is_empty() {
1267                    write!(
1268                        prompt,
1269                        "User edited {:?}:\n```diff\n{}\n```",
1270                        new_path, diff
1271                    )
1272                    .unwrap();
1273                }
1274
1275                prompt
1276            }
1277        }
1278    }
1279}
1280
/// The prediction the provider is currently offering.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
    /// Set by `did_show`; reported with the rejection when the prediction is discarded.
    was_shown: bool,
}
1287
1288impl CurrentEditPrediction {
1289    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1290        if self.buffer_id != old_completion.buffer_id {
1291            return true;
1292        }
1293
1294        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1295            return true;
1296        };
1297        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1298            return false;
1299        };
1300
1301        if old_edits.len() == 1 && new_edits.len() == 1 {
1302            let (old_range, old_text) = &old_edits[0];
1303            let (new_range, new_text) = &new_edits[0];
1304            new_range == old_range && new_text.starts_with(old_text.as_ref())
1305        } else {
1306            true
1307        }
1308    }
1309}
1310
/// A prediction request that is still in flight.
struct PendingCompletion {
    /// Increasing id assigned by `refresh` from `next_pending_completion_id`.
    id: usize,
    /// Task driving the request. It is held only to keep the request alive; presumably dropping
    /// it cancels the request (gpui task semantics — confirm).
    _task: Task<()>,
}
1315
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has made any choice.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1356
1357async fn llm_token_retry(
1358    llm_token: &LlmApiToken,
1359    client: &Arc<Client>,
1360    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1361) -> Result<Response<AsyncBody>> {
1362    let mut did_retry = false;
1363    let http_client = client.http_client();
1364    let mut token = llm_token.acquire(client).await?;
1365    loop {
1366        let request = build_request(token.clone())?;
1367        let response = http_client.send(request).await?;
1368
1369        if !did_retry
1370            && !response.status().is_success()
1371            && response
1372                .headers()
1373                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1374                .is_some()
1375        {
1376            did_retry = true;
1377            token = llm_token.refresh(client).await?;
1378            continue;
1379        }
1380
1381        return Ok(response);
1382    }
1383}
1384
/// Edit-prediction provider that requests predictions from Zeta.
pub struct ZetaEditPredictionProvider {
    /// Shared Zeta state that performs requests and tracks ratings and data collection.
    zeta: Entity<Zeta>,
    /// Buffer whose file determines the reported data-collection state, when present.
    singleton_buffer: Option<Entity<Buffer>>,
    /// In-flight requests, oldest first; `refresh` keeps at most two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next pending request.
    next_pending_completion_id: usize,
    /// Prediction currently offered to the editor, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
    /// Project the predictions are requested for.
    project: Entity<Project>,
}
1394
1395impl ZetaEditPredictionProvider {
1396    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1397
1398    pub fn new(
1399        zeta: Entity<Zeta>,
1400        project: Entity<Project>,
1401        singleton_buffer: Option<Entity<Buffer>>,
1402    ) -> Self {
1403        Self {
1404            zeta,
1405            singleton_buffer,
1406            pending_completions: ArrayVec::new(),
1407            next_pending_completion_id: 0,
1408            current_completion: None,
1409            last_request_timestamp: Instant::now(),
1410            project,
1411        }
1412    }
1413}
1414
1415impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1416    fn name() -> &'static str {
1417        "zed-predict"
1418    }
1419
1420    fn display_name() -> &'static str {
1421        "Zed's Edit Predictions"
1422    }
1423
1424    fn show_completions_in_menu() -> bool {
1425        true
1426    }
1427
1428    fn show_tab_accept_marker() -> bool {
1429        true
1430    }
1431
1432    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1433        if let Some(buffer) = &self.singleton_buffer
1434            && let Some(file) = buffer.read(cx).file()
1435        {
1436            let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
1437            if self.zeta.read(cx).data_collection_choice.is_enabled() {
1438                DataCollectionState::Enabled {
1439                    is_project_open_source,
1440                }
1441            } else {
1442                DataCollectionState::Disabled {
1443                    is_project_open_source,
1444                }
1445            }
1446        } else {
1447            return DataCollectionState::Disabled {
1448                is_project_open_source: false,
1449            };
1450        }
1451    }
1452
1453    fn toggle_data_collection(&mut self, cx: &mut App) {
1454        self.zeta
1455            .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
1456    }
1457
1458    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1459        self.zeta.read(cx).usage(cx)
1460    }
1461
1462    fn is_enabled(
1463        &self,
1464        _buffer: &Entity<Buffer>,
1465        _cursor_position: language::Anchor,
1466        _cx: &App,
1467    ) -> bool {
1468        true
1469    }
1470    fn is_refreshing(&self) -> bool {
1471        !self.pending_completions.is_empty()
1472    }
1473
1474    fn refresh(
1475        &mut self,
1476        buffer: Entity<Buffer>,
1477        position: language::Anchor,
1478        _debounce: bool,
1479        cx: &mut Context<Self>,
1480    ) {
1481        if self.zeta.read(cx).update_required {
1482            return;
1483        }
1484
1485        if self
1486            .zeta
1487            .read(cx)
1488            .user_store
1489            .read_with(cx, |user_store, _cx| {
1490                user_store.account_too_young() || user_store.has_overdue_invoices()
1491            })
1492        {
1493            return;
1494        }
1495
1496        if let Some(current_completion) = self.current_completion.as_ref() {
1497            let snapshot = buffer.read(cx).snapshot();
1498            if current_completion
1499                .completion
1500                .interpolate(&snapshot)
1501                .is_some()
1502            {
1503                return;
1504            }
1505        }
1506
1507        let pending_completion_id = self.next_pending_completion_id;
1508        self.next_pending_completion_id += 1;
1509        let last_request_timestamp = self.last_request_timestamp;
1510
1511        let project = self.project.clone();
1512        let task = cx.spawn(async move |this, cx| {
1513            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1514                .checked_duration_since(Instant::now())
1515            {
1516                cx.background_executor().timer(timeout).await;
1517            }
1518
1519            let completion_request = this.update(cx, |this, cx| {
1520                this.last_request_timestamp = Instant::now();
1521                this.zeta.update(cx, |zeta, cx| {
1522                    zeta.request_completion(&project, &buffer, position, cx)
1523                })
1524            });
1525
1526            let completion = match completion_request {
1527                Ok(completion_request) => {
1528                    let completion_request = completion_request.await;
1529                    completion_request.map(|c| {
1530                        c.map(|completion| CurrentEditPrediction {
1531                            buffer_id: buffer.entity_id(),
1532                            completion,
1533                            was_shown: false,
1534                        })
1535                    })
1536                }
1537                Err(error) => Err(error),
1538            };
1539            let Some(new_completion) = completion
1540                .context("edit prediction failed")
1541                .log_err()
1542                .flatten()
1543            else {
1544                this.update(cx, |this, cx| {
1545                    if this.pending_completions[0].id == pending_completion_id {
1546                        this.pending_completions.remove(0);
1547                    } else {
1548                        this.pending_completions.clear();
1549                    }
1550
1551                    cx.notify();
1552                })
1553                .ok();
1554                return;
1555            };
1556
1557            this.update(cx, |this, cx| {
1558                if this.pending_completions[0].id == pending_completion_id {
1559                    this.pending_completions.remove(0);
1560                } else {
1561                    this.pending_completions.clear();
1562                }
1563
1564                if let Some(old_completion) = this.current_completion.as_ref() {
1565                    let snapshot = buffer.read(cx).snapshot();
1566                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1567                        this.zeta.update(cx, |zeta, cx| {
1568                            zeta.completion_shown(&new_completion.completion, cx);
1569                        });
1570                        this.current_completion = Some(new_completion);
1571                    }
1572                } else {
1573                    this.zeta.update(cx, |zeta, cx| {
1574                        zeta.completion_shown(&new_completion.completion, cx);
1575                    });
1576                    this.current_completion = Some(new_completion);
1577                }
1578
1579                cx.notify();
1580            })
1581            .ok();
1582        });
1583
1584        // We always maintain at most two pending completions. When we already
1585        // have two, we replace the newest one.
1586        if self.pending_completions.len() <= 1 {
1587            self.pending_completions.push(PendingCompletion {
1588                id: pending_completion_id,
1589                _task: task,
1590            });
1591        } else if self.pending_completions.len() == 2 {
1592            self.pending_completions.pop();
1593            self.pending_completions.push(PendingCompletion {
1594                id: pending_completion_id,
1595                _task: task,
1596            });
1597        }
1598    }
1599
1600    fn cycle(
1601        &mut self,
1602        _buffer: Entity<Buffer>,
1603        _cursor_position: language::Anchor,
1604        _direction: edit_prediction::Direction,
1605        _cx: &mut Context<Self>,
1606    ) {
1607        // Right now we don't support cycling.
1608    }
1609
1610    fn accept(&mut self, cx: &mut Context<Self>) {
1611        let completion_id = self
1612            .current_completion
1613            .as_ref()
1614            .map(|completion| completion.completion.id);
1615        if let Some(completion_id) = completion_id {
1616            self.zeta
1617                .update(cx, |zeta, cx| {
1618                    zeta.accept_edit_prediction(completion_id, cx)
1619                })
1620                .detach();
1621        }
1622        self.pending_completions.clear();
1623    }
1624
1625    fn discard(&mut self, cx: &mut Context<Self>) {
1626        self.pending_completions.clear();
1627        if let Some(completion) = self.current_completion.take() {
1628            self.zeta.update(cx, |zeta, cx| {
1629                zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
1630            });
1631        }
1632    }
1633
1634    fn did_show(&mut self, _cx: &mut Context<Self>) {
1635        if let Some(current_completion) = self.current_completion.as_mut() {
1636            current_completion.was_shown = true;
1637        }
1638    }
1639
1640    fn suggest(
1641        &mut self,
1642        buffer: &Entity<Buffer>,
1643        cursor_position: language::Anchor,
1644        cx: &mut Context<Self>,
1645    ) -> Option<edit_prediction::EditPrediction> {
1646        let CurrentEditPrediction {
1647            buffer_id,
1648            completion,
1649            ..
1650        } = self.current_completion.as_mut()?;
1651
1652        // Invalidate previous completion if it was generated for a different buffer.
1653        if *buffer_id != buffer.entity_id() {
1654            self.current_completion.take();
1655            return None;
1656        }
1657
1658        let buffer = buffer.read(cx);
1659        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1660            self.current_completion.take();
1661            return None;
1662        };
1663
1664        let cursor_row = cursor_position.to_point(buffer).row;
1665        let (closest_edit_ix, (closest_edit_range, _)) =
1666            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1667                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1668                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1669                cmp::min(distance_from_start, distance_from_end)
1670            })?;
1671
1672        let mut edit_start_ix = closest_edit_ix;
1673        for (range, _) in edits[..edit_start_ix].iter().rev() {
1674            let distance_from_closest_edit =
1675                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1676            if distance_from_closest_edit <= 1 {
1677                edit_start_ix -= 1;
1678            } else {
1679                break;
1680            }
1681        }
1682
1683        let mut edit_end_ix = closest_edit_ix + 1;
1684        for (range, _) in &edits[edit_end_ix..] {
1685            let distance_from_closest_edit =
1686                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1687            if distance_from_closest_edit <= 1 {
1688                edit_end_ix += 1;
1689            } else {
1690                break;
1691            }
1692        }
1693
1694        Some(edit_prediction::EditPrediction::Local {
1695            id: Some(completion.id.to_string().into()),
1696            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1697            edit_preview: Some(completion.edit_preview.clone()),
1698        })
1699    }
1700}
1701
/// Assumed average number of text bytes per model token, used only to bound model input size.
/// The figure is deliberately small, so the token count estimated for a piece of text errs high
/// and token limits are underestimated rather than overestimated.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1709
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::{FakeSystemClock, ReplicaId};
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use parking_lot::Mutex;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    use super::*;

    /// Text of a 0BSD license. The tests below place it at a worktree root when they expect that
    /// worktree to be treated as open source.
    const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Each step edits the buffer and checks how the prediction's remaining edits are
        // re-projected onto the new contents.
        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test(cx);

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
        );

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
        );
    }

    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        assert_eq!(
            apply_edit_prediction(buffer_content, completion_response, cx).await,
            "lorem\nipsum"
        );
    }

    #[gpui::test]
    async fn test_can_collect_data(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/src/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));

        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Disabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let buffer = cx.new(|_cx| {
            Buffer::remote(
                language::BufferId::new(1).unwrap(),
                ReplicaId::new(1),
                language::Capability::ReadWrite,
                "fn main() {\n    println!(\"Hello\");\n}",
            )
        });

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/project"),
            json!({
                "LICENSE": BSD_0_TXT,
                ".env": "SECRET_KEY=secret"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/.env"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;
        let buffer = cx.new(|cx| Buffer::local("", cx));

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/open_source_worktree"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [
                path!("/open_source_worktree").as_ref(),
                path!("/closed_source_worktree").as_ref(),
            ],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));

        // Re-point the buffer at a file in the unlicensed worktree, simulating a move.
        let closed_source_file = project
            .update(cx, |project, cx| {
                let worktree2 = project
                    .worktree_for_root_name("closed_source_worktree", cx)
                    .unwrap();
                worktree2.update(cx, |worktree2, cx| {
                    worktree2.load_file(rel_path("main.rs"), cx)
                })
            })
            .await
            .unwrap()
            .file;

        buffer.update(cx, |buffer, cx| {
            buffer.file_updated(closed_source_file, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/worktree1"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/worktree1/main.rs"), cx)
            })
            .await
            .unwrap();
        let private_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/worktree2/private.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));

        // Requesting a prediction for the private buffer also registers it, so that Zeta watches
        // it for edits.
        run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));

        private_buffer.update(cx, |private_buffer, cx| {
            private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));

        // Make an edit large enough to use up the event budget, so that the earlier edit to
        // private_buffer can no longer be included in the request.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
                None,
                cx,
            );
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));
    }

    /// Installs a test `SettingsStore` as a global.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Serves `completion_response` as the model output for a local buffer that holds
    /// `buffer_content`, applies the resulting prediction's edits, and returns the buffer text.
    async fn apply_edit_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> String {
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let (zeta, _, response) = make_test_zeta(&project, cx).await;
        *response.lock() = completion_response.to_string();
        let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
        buffer.update(cx, |buffer, cx| {
            buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
        });
        buffer.read_with(cx, |buffer, _| buffer.text())
    }

    /// Registers `buffer` with `zeta` and requests a prediction with the cursor anchored before
    /// row 1, column 0. Panics if the request fails or yields no prediction.
    async fn run_edit_prediction(
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        zeta: &Entity<Zeta>,
        cx: &mut TestAppContext,
    ) -> EditPrediction {
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, project, cx));
        cx.background_executor.run_until_parked();
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(project, buffer, cursor, cx)
        });
        completion_task.await.unwrap().unwrap()
    }

    /// Reads `can_collect_data` from the most recent `/predict_edits/v2` request recorded by the
    /// fake HTTP client built in `make_test_zeta`. Panics if no request has been recorded yet.
    fn last_request_can_collect_data(captured_request: &Mutex<Option<PredictEditsBody>>) -> bool {
        match &*captured_request.lock() {
            Some(request) => request.can_collect_data,
            None => panic!("no /predict_edits/v2 request was captured"),
        }
    }

    /// Builds a `Zeta` backed by a fake HTTP client.
    ///
    /// The fake client answers `POST /client/llm_tokens` with a fixed token, and answers
    /// `POST /predict_edits/v2` by recording the request body and replying with the current
    /// contents of the returned response slot. Every other request gets a 404. A license
    /// detection watcher is created for every worktree already in `project`.
    ///
    /// Returns the `Zeta` entity, the slot holding the last captured request body, and the
    /// response slot (initially a small default `main.rs` prediction).
    async fn make_test_zeta(
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        Arc<Mutex<Option<PredictEditsBody>>>,
        Arc<Mutex<String>>,
    ) {
        let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
        };
        let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
        let completion_response: Arc<Mutex<String>> =
            Arc::new(Mutex::new(default_response.to_string()));
        let http_client = FakeHttpClient::create({
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            move |req| {
                let captured_request = captured_request.clone();
                let completion_response = completion_response.clone();
                async move {
                    match (req.method(), req.uri().path()) {
                        (&Method::POST, "/client/llm_tokens") => {
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&CreateLlmTokenResponse {
                                        token: LlmToken("the-llm-token".to_string()),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        (&Method::POST, "/predict_edits/v2") => {
                            let mut request_body = String::new();
                            req.into_body().read_to_string(&mut request_body).await?;
                            *captured_request.lock() =
                                Some(serde_json::from_str(&request_body).unwrap());
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&PredictEditsResponse {
                                        request_id: Uuid::new_v4().to_string(),
                                        output_excerpt: completion_response.lock().clone(),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        _ => Ok(http_client::Response::builder()
                            .status(404)
                            .body("Not Found".into())
                            .unwrap()),
                    }
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let _server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new(|cx| {
            let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);

            let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
            for worktree in worktrees {
                let worktree_id = worktree.read(cx).id();
                zeta.license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
            }

            zeta
        });

        (zeta, captured_request, completion_response)
    }

    /// Converts offset ranges into anchored edits against `buffer`'s current contents.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchored edits back to offset ranges in `buffer`'s current contents.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}