zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::{Dismissable, KEY_VALUE_STORE};
  11use db::smol::stream::StreamExt as _;
  12use edit_prediction::DataCollectionState;
  13use futures::channel::mpsc;
  14pub use init::*;
  15use license_detection::LicenseDetectionWatcher;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use arrayvec::ArrayVec;
  20use client::{Client, EditPredictionUsage, UserStore};
  21use cloud_llm_client::{
  22    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
  23    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  24    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
  25    ZED_VERSION_HEADER_NAME,
  26};
  27use collections::{HashMap, HashSet, VecDeque};
  28use futures::AsyncReadExt;
  29use gpui::{
  30    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  31    SharedString, Subscription, Task, actions,
  32};
  33use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  34use input_excerpt::excerpt_for_cursor_position;
  35use language::{
  36    Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  37};
  38use language_model::{LlmApiToken, RefreshLlmTokenListener};
  39use project::{Project, ProjectPath};
  40use release_channel::AppVersion;
  41use settings::WorktreeId;
  42use std::collections::hash_map;
  43use std::mem;
  44use std::str::FromStr;
  45use std::{
  46    cmp,
  47    fmt::Write,
  48    future::Future,
  49    ops::Range,
  50    path::Path,
  51    rc::Rc,
  52    sync::Arc,
  53    time::{Duration, Instant},
  54};
  55use telemetry_events::EditPredictionRating;
  56use thiserror::Error;
  57use util::ResultExt;
  58use util::rel_path::RelPath;
  59use uuid::Uuid;
  60use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  61use worktree::Worktree;
  62
// Special marker strings. NOTE(review): they are not referenced in the portion of the
// file visible here; presumably they delimit regions of the prompt excerpt — confirm
// against `input_excerpt` and the response parsing code.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// A buffer change that directly continues the previous event and arrives within this
/// interval is merged into that event instead of being recorded separately
/// (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key checked by `ZedPredictUpsell::dismissed` to detect users who
/// already went through the previous onboarding flow.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets. `MAX_EVENT_TOKENS` is passed to `prompt_for_events_impl` when a
// completion is requested; the other two are not used in the visible portion of the
// file (presumably they bound the excerpt sizes — TODO confirm).
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track per project. Once the history reaches this
/// length, `Zeta::push_event` drops the oldest half before appending.
const MAX_EVENT_COUNT: usize = 16;
  76
// Declares the actions belonging to the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  84
/// Identifier of a single edit prediction, wrapping a UUID.
///
/// Its string form is what gets sent as the `request_id` when a prediction is
/// accepted (see `Zeta::accept_edit_prediction`).
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
  87
  88impl From<EditPredictionId> for gpui::ElementId {
  89    fn from(value: EditPredictionId) -> Self {
  90        gpui::ElementId::Uuid(value.0)
  91    }
  92}
  93
  94impl std::fmt::Display for EditPredictionId {
  95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  96        write!(f, "{}", self.0)
  97    }
  98}
  99
 100struct ZedPredictUpsell;
 101
 102impl Dismissable for ZedPredictUpsell {
 103    const KEY: &'static str = "dismissed-edit-predict-upsell";
 104
 105    fn dismissed() -> bool {
 106        // To make this backwards compatible with older versions of Zed, we
 107        // check if the user has seen the previous Edit Prediction Onboarding
 108        // before, by checking the data collection choice which was written to
 109        // the database once the user clicked on "Accept and Enable"
 110        if KEY_VALUE_STORE
 111            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 112            .log_err()
 113            .is_some_and(|s| s.is_some())
 114        {
 115            return true;
 116        }
 117
 118        KEY_VALUE_STORE
 119            .read_kvp(Self::KEY)
 120            .log_err()
 121            .is_some_and(|s| s.is_some())
 122    }
 123}
 124
 125pub fn should_show_upsell_modal() -> bool {
 126    !ZedPredictUpsell::dismissed()
 127}
 128
/// Global wrapper holding the shared `Zeta` entity; set by `Zeta::register` and
/// read by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 133
/// A single edit prediction together with the inputs and timing data recorded
/// for it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    id: EditPredictionId,
    // NOTE(review): presumably the full path of the file the prediction was made
    // for — the construction site is outside the visible portion of the file.
    path: Arc<Path>,
    excerpt_range: Range<usize>,
    cursor_offset: usize,
    /// Predicted edits as pairs of an anchor range and a text (presumably the
    /// replacement text for that range — TODO confirm).
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Snapshot the `edits` refer to; used as the old snapshot in `interpolate`.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// Start of the interval measured by `latency`.
    buffer_snapshotted_at: Instant,
    /// End of the interval measured by `latency`.
    response_received_at: Instant,
}
 150
 151impl EditPrediction {
 152    fn latency(&self) -> Duration {
 153        self.response_received_at
 154            .duration_since(self.buffer_snapshotted_at)
 155    }
 156
 157    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 158        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 159    }
 160}
 161
 162impl std::fmt::Debug for EditPrediction {
 163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 164        f.debug_struct("EditPrediction")
 165            .field("id", &self.id)
 166            .field("path", &self.path)
 167            .field("edits", &self.edits)
 168            .finish_non_exhaustive()
 169    }
 170}
 171
/// Shared edit prediction state: per-project event history, pending rejections,
/// the LLM API token and usage/licensing bookkeeping.
pub struct Zeta {
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    client: Arc<Client>,
    shown_completions: VecDeque<EditPrediction>,
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: DataCollectionChoice,
    /// Rejections that have not yet been acknowledged by the server; sent by
    /// `reject_edit_predictions` and drained when that request succeeds.
    discarded_completions: Vec<EditPredictionRejection>,
    llm_token: LlmApiToken,
    /// Keeps the subscription alive that refreshes `llm_token` whenever the
    /// refresh listener emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license detection watcher per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    discard_completions_debounce_task: Option<Task<()>>,
    /// Sending `()` here makes the background task spawned in `new` call
    /// `reject_edit_predictions`.
    discard_completions_tx: mpsc::UnboundedSender<()>,
}
 188
/// State that `Zeta` keeps for a single project.
struct ZetaProject {
    /// Recent events, oldest first; maintained by `Zeta::push_event`.
    events: VecDeque<Event>,
    /// Buffers being tracked for this project, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
 193
 194impl Zeta {
 195    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 196        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 197    }
 198
 199    pub fn register(
 200        worktree: Option<Entity<Worktree>>,
 201        client: Arc<Client>,
 202        user_store: Entity<UserStore>,
 203        cx: &mut App,
 204    ) -> Entity<Self> {
 205        let this = Self::global(cx).unwrap_or_else(|| {
 206            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 207            cx.set_global(ZetaGlobal(entity.clone()));
 208            entity
 209        });
 210
 211        this.update(cx, move |this, cx| {
 212            if let Some(worktree) = worktree {
 213                let worktree_id = worktree.read(cx).id();
 214                this.license_detection_watchers
 215                    .entry(worktree_id)
 216                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 217            }
 218        });
 219
 220        this
 221    }
 222
 223    pub fn clear_history(&mut self) {
 224        for zeta_project in self.projects.values_mut() {
 225            zeta_project.events.clear();
 226        }
 227    }
 228
 229    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 230        self.user_store.read(cx).edit_prediction_usage()
 231    }
 232
 233    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 234        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 235        let data_collection_choice = Self::load_data_collection_choice();
 236        let (reject_tx, mut reject_rx) = mpsc::unbounded();
 237        cx.spawn(async move |this, cx| {
 238            while let Some(()) = reject_rx.next().await {
 239                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
 240                    .await
 241                    .log_err();
 242            }
 243            anyhow::Ok(())
 244        })
 245        .detach();
 246
 247        Self {
 248            projects: HashMap::default(),
 249            client,
 250            shown_completions: VecDeque::new(),
 251            rated_completions: HashSet::default(),
 252            discarded_completions: Vec::new(),
 253            discard_completions_debounce_task: None,
 254            discard_completions_tx: reject_tx,
 255            data_collection_choice,
 256            llm_token: LlmApiToken::default(),
 257            _llm_token_subscription: cx.subscribe(
 258                &refresh_llm_token_listener,
 259                |this, _listener, _event, cx| {
 260                    let client = this.client.clone();
 261                    let llm_token = this.llm_token.clone();
 262                    cx.spawn(async move |_this, _cx| {
 263                        llm_token.refresh(&client).await?;
 264                        anyhow::Ok(())
 265                    })
 266                    .detach_and_log_err(cx);
 267                },
 268            ),
 269            update_required: false,
 270            license_detection_watchers: HashMap::default(),
 271            user_store,
 272        }
 273    }
 274
 275    fn get_or_init_zeta_project(
 276        &mut self,
 277        project: &Entity<Project>,
 278        cx: &mut Context<Self>,
 279    ) -> &mut ZetaProject {
 280        let project_id = project.entity_id();
 281        match self.projects.entry(project_id) {
 282            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 283            hash_map::Entry::Vacant(entry) => {
 284                cx.observe_release(project, move |this, _, _cx| {
 285                    this.projects.remove(&project_id);
 286                })
 287                .detach();
 288                entry.insert(ZetaProject {
 289                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 290                    registered_buffers: HashMap::default(),
 291                })
 292            }
 293        }
 294    }
 295
 296    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 297        let events = &mut zeta_project.events;
 298
 299        if let Some(Event::BufferChange {
 300            new_snapshot: last_new_snapshot,
 301            timestamp: last_timestamp,
 302            ..
 303        }) = events.back_mut()
 304        {
 305            // Coalesce edits for the same buffer when they happen one after the other.
 306            let Event::BufferChange {
 307                old_snapshot,
 308                new_snapshot,
 309                timestamp,
 310            } = &event;
 311
 312            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 313                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 314                && old_snapshot.version == last_new_snapshot.version
 315            {
 316                *last_new_snapshot = new_snapshot.clone();
 317                *last_timestamp = *timestamp;
 318                return;
 319            }
 320        }
 321
 322        if events.len() >= MAX_EVENT_COUNT {
 323            // These are halved instead of popping to improve prompt caching.
 324            events.drain(..MAX_EVENT_COUNT / 2);
 325        }
 326
 327        events.push_back(event);
 328    }
 329
 330    pub fn register_buffer(
 331        &mut self,
 332        buffer: &Entity<Buffer>,
 333        project: &Entity<Project>,
 334        cx: &mut Context<Self>,
 335    ) {
 336        let zeta_project = self.get_or_init_zeta_project(project, cx);
 337        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 338    }
 339
 340    fn register_buffer_impl<'a>(
 341        zeta_project: &'a mut ZetaProject,
 342        buffer: &Entity<Buffer>,
 343        project: &Entity<Project>,
 344        cx: &mut Context<Self>,
 345    ) -> &'a mut RegisteredBuffer {
 346        let buffer_id = buffer.entity_id();
 347        match zeta_project.registered_buffers.entry(buffer_id) {
 348            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 349            hash_map::Entry::Vacant(entry) => {
 350                let snapshot = buffer.read(cx).snapshot();
 351                let project_entity_id = project.entity_id();
 352                entry.insert(RegisteredBuffer {
 353                    snapshot,
 354                    _subscriptions: [
 355                        cx.subscribe(buffer, {
 356                            let project = project.downgrade();
 357                            move |this, buffer, event, cx| {
 358                                if let language::BufferEvent::Edited = event
 359                                    && let Some(project) = project.upgrade()
 360                                {
 361                                    this.report_changes_for_buffer(&buffer, &project, cx);
 362                                }
 363                            }
 364                        }),
 365                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 366                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 367                            else {
 368                                return;
 369                            };
 370                            zeta_project.registered_buffers.remove(&buffer_id);
 371                        }),
 372                    ],
 373                })
 374            }
 375        }
 376    }
 377
    /// Requests an edit prediction for `buffer` at `cursor`.
    ///
    /// The work is split into a synchronous part (snapshotting the buffer, cloning
    /// the project's event history, starting context gathering) and a spawned
    /// task that awaits the gathered context, calls `perform_predict_edits`,
    /// records usage, and converts the response via
    /// `process_completion_response`.
    ///
    /// `perform_predict_edits` is the function that actually obtains the
    /// prediction; it is injected so that callers can substitute it (see
    /// `fake_completion` and `request_completion`).
    ///
    /// If the request fails with `ZedUpdateRequiredError`, `update_required` is
    /// set and a notification linking to the releases page is shown before the
    /// error is returned.
    fn request_completion_impl<F, R>(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the interval later reported as context latency and stored on
        // the prediction.
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
        let zeta = cx.entity();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Copy the event history so the spawned task does not borrow `self`.
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let mut events = Vec::with_capacity(zeta_project.events.len());
        events.extend(zeta_project.events.iter().cloned());
        let events = Arc::new(events);

        // Git info is only looked up when the file itself may be collected.
        let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
            let can_collect_file = self.can_collect_file(file, cx);
            let git_info = if can_collect_file {
                git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
            } else {
                None
            };
            (git_info, can_collect_file)
        } else {
            (None, false)
        };

        // Buffers without a backing file are reported under the path "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().into_owned();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let prompt_for_events = {
            let events = events.clone();
            move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
        };
        let gather_task = gather_context(
            full_path_str,
            &snapshot,
            cursor_point,
            prompt_for_events,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                mut body,
                editable_range,
                included_events_count,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            // The included events are the last `included_events_count` entries of
            // the history copied above.
            let included_events = &events[events.len() - included_events_count..events.len()];
            // Data may only be collected when both the file and every included
            // event allow it; if `this` can no longer be read, default to `false`.
            body.can_collect_data = can_collect_file
                && this
                    .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                    .unwrap_or(false);
            if body.can_collect_data {
                body.git_info = git_info;
            }

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the inputs, since `body` is moved into the request.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        // Best effort: a failure to update the app is ignored and
                        // the original error is still returned below.
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage reporting is best effort; a failed update is ignored.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests
            // (a random `u8` is <= 2 for 3 of its 256 possible values)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 547
 548    #[cfg(any(test, feature = "test-support"))]
 549    pub fn fake_completion(
 550        &mut self,
 551        project: &Entity<Project>,
 552        buffer: &Entity<Buffer>,
 553        position: language::Anchor,
 554        response: PredictEditsResponse,
 555        cx: &mut Context<Self>,
 556    ) -> Task<Result<Option<EditPrediction>>> {
 557        self.request_completion_impl(project, buffer, position, cx, |_params| {
 558            std::future::ready(Ok((response, None)))
 559        })
 560    }
 561
 562    pub fn request_completion(
 563        &mut self,
 564        project: &Entity<Project>,
 565        buffer: &Entity<Buffer>,
 566        position: language::Anchor,
 567        cx: &mut Context<Self>,
 568    ) -> Task<Result<Option<EditPrediction>>> {
 569        self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
 570    }
 571
 572    pub fn perform_predict_edits(
 573        params: PerformPredictEditsParams,
 574    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 575        async move {
 576            let PerformPredictEditsParams {
 577                client,
 578                llm_token,
 579                app_version,
 580                body,
 581                ..
 582            } = params;
 583
 584            let http_client = client.http_client();
 585            let mut token = llm_token.acquire(&client).await?;
 586            let mut did_retry = false;
 587
 588            loop {
 589                let request_builder = http_client::Request::builder().method(Method::POST);
 590                let request_builder =
 591                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 592                        request_builder.uri(predict_edits_url)
 593                    } else {
 594                        request_builder.uri(
 595                            http_client
 596                                .build_zed_llm_url("/predict_edits/v2", &[])?
 597                                .as_ref(),
 598                        )
 599                    };
 600                let request = request_builder
 601                    .header("Content-Type", "application/json")
 602                    .header("Authorization", format!("Bearer {}", token))
 603                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 604                    .body(serde_json::to_string(&body)?.into())?;
 605
 606                let mut response = http_client.send(request).await?;
 607
 608                if let Some(minimum_required_version) = response
 609                    .headers()
 610                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 611                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 612                {
 613                    anyhow::ensure!(
 614                        app_version >= minimum_required_version,
 615                        ZedUpdateRequiredError {
 616                            minimum_version: minimum_required_version
 617                        }
 618                    );
 619                }
 620
 621                if response.status().is_success() {
 622                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 623
 624                    let mut body = String::new();
 625                    response.body_mut().read_to_string(&mut body).await?;
 626                    return Ok((serde_json::from_str(&body)?, usage));
 627                } else if !did_retry
 628                    && response
 629                        .headers()
 630                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 631                        .is_some()
 632                {
 633                    did_retry = true;
 634                    token = llm_token.refresh(&client).await?;
 635                } else {
 636                    let mut body = String::new();
 637                    response.body_mut().read_to_string(&mut body).await?;
 638                    anyhow::bail!(
 639                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 640                        response.status(),
 641                        body
 642                    );
 643                }
 644            }
 645        }
 646    }
 647
 648    fn accept_edit_prediction(
 649        &mut self,
 650        request_id: EditPredictionId,
 651        cx: &mut Context<Self>,
 652    ) -> Task<Result<()>> {
 653        let client = self.client.clone();
 654        let llm_token = self.llm_token.clone();
 655        let app_version = AppVersion::global(cx);
 656        cx.spawn(async move |this, cx| {
 657            let http_client = client.http_client();
 658            let mut response = llm_token_retry(&llm_token, &client, |token| {
 659                let request_builder = http_client::Request::builder().method(Method::POST);
 660                let request_builder =
 661                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 662                        request_builder.uri(accept_prediction_url)
 663                    } else {
 664                        request_builder.uri(
 665                            http_client
 666                                .build_zed_llm_url("/predict_edits/accept", &[])?
 667                                .as_ref(),
 668                        )
 669                    };
 670                Ok(request_builder
 671                    .header("Content-Type", "application/json")
 672                    .header("Authorization", format!("Bearer {}", token))
 673                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 674                    .body(
 675                        serde_json::to_string(&AcceptEditPredictionBody {
 676                            request_id: request_id.0.to_string(),
 677                        })?
 678                        .into(),
 679                    )?)
 680            })
 681            .await?;
 682
 683            if let Some(minimum_required_version) = response
 684                .headers()
 685                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 686                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 687                && app_version < minimum_required_version
 688            {
 689                return Err(anyhow!(ZedUpdateRequiredError {
 690                    minimum_version: minimum_required_version
 691                }));
 692            }
 693
 694            if response.status().is_success() {
 695                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 696                    this.update(cx, |this, cx| {
 697                        this.user_store.update(cx, |user_store, cx| {
 698                            user_store.update_edit_prediction_usage(usage, cx);
 699                        });
 700                    })?;
 701                }
 702
 703                Ok(())
 704            } else {
 705                let mut body = String::new();
 706                response.body_mut().read_to_string(&mut body).await?;
 707                Err(anyhow!(
 708                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 709                    response.status(),
 710                    body
 711                ))
 712            }
 713        })
 714    }
 715
 716    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 717        let client = self.client.clone();
 718        let llm_token = self.llm_token.clone();
 719        let app_version = AppVersion::global(cx);
 720        let last_rejection = self.discarded_completions.last().cloned();
 721        let body = serde_json::to_string(&RejectEditPredictionsBody {
 722            rejections: self.discarded_completions.clone(),
 723        })
 724        .ok();
 725
 726        let Some(last_rejection) = last_rejection else {
 727            return Task::ready(anyhow::Ok(()));
 728        };
 729
 730        cx.spawn(async move |this, cx| {
 731            let http_client = client.http_client();
 732            let mut response = llm_token_retry(&llm_token, &client, |token| {
 733                let request_builder = http_client::Request::builder().method(Method::POST);
 734                let request_builder = request_builder.uri(
 735                    http_client
 736                        .build_zed_llm_url("/predict_edits/reject", &[])?
 737                        .as_ref(),
 738                );
 739                Ok(request_builder
 740                    .header("Content-Type", "application/json")
 741                    .header("Authorization", format!("Bearer {}", token))
 742                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 743                    .body(
 744                        body.as_ref()
 745                            .context("failed to serialize body")?
 746                            .clone()
 747                            .into(),
 748                    )?)
 749            })
 750            .await?;
 751
 752            if let Some(minimum_required_version) = response
 753                .headers()
 754                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 755                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 756                && app_version < minimum_required_version
 757            {
 758                return Err(anyhow!(ZedUpdateRequiredError {
 759                    minimum_version: minimum_required_version
 760                }));
 761            }
 762
 763            if response.status().is_success() {
 764                this.update(cx, |this, _| {
 765                    if let Some(ix) = this
 766                        .discarded_completions
 767                        .iter()
 768                        .position(|rejection| rejection.request_id == last_rejection.request_id)
 769                    {
 770                        this.discarded_completions.drain(..ix + 1);
 771                    }
 772                })
 773            } else {
 774                let mut body = String::new();
 775                response.body_mut().read_to_string(&mut body).await?;
 776                Err(anyhow!(
 777                    "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
 778                    response.status(),
 779                    body
 780                ))
 781            }
 782        })
 783    }
 784
    /// Turns a raw prediction response into an [`EditPrediction`] for `buffer`.
    ///
    /// The returned task:
    /// 1. parses the response's output excerpt into anchored edits against `snapshot` on the
    ///    background executor (see `Self::parse_edits`),
    /// 2. re-interpolates those edits onto the buffer's current snapshot, and
    /// 3. awaits an edit preview for the interpolated edits.
    ///
    /// Resolves to `Ok(None)` when the edits can no longer be interpolated onto the buffer's
    /// current contents.
    ///
    /// # Errors
    ///
    /// Fails when the output excerpt cannot be parsed, when reading the buffer entity fails,
    /// or when the response's request id is not a valid UUID. The request id is only parsed
    /// after interpolation succeeded.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse the model output on the background executor; the closure gets its own
            // clones so the originals stay available for the final `EditPrediction`.
            let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Map the edits from the snapshot the request was built from onto the buffer's
            // current snapshot; `None` from `interpolate_edits` ends the task with `Ok(None)`.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                move |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
                            .into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 849
 850    fn parse_edits(
 851        output_excerpt: Arc<str>,
 852        editable_range: Range<usize>,
 853        snapshot: &BufferSnapshot,
 854    ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
 855        let content = output_excerpt.replace(CURSOR_MARKER, "");
 856
 857        let start_markers = content
 858            .match_indices(EDITABLE_REGION_START_MARKER)
 859            .collect::<Vec<_>>();
 860        anyhow::ensure!(
 861            start_markers.len() == 1,
 862            "expected exactly one start marker, found {}",
 863            start_markers.len()
 864        );
 865
 866        let end_markers = content
 867            .match_indices(EDITABLE_REGION_END_MARKER)
 868            .collect::<Vec<_>>();
 869        anyhow::ensure!(
 870            end_markers.len() == 1,
 871            "expected exactly one end marker, found {}",
 872            end_markers.len()
 873        );
 874
 875        let sof_markers = content
 876            .match_indices(START_OF_FILE_MARKER)
 877            .collect::<Vec<_>>();
 878        anyhow::ensure!(
 879            sof_markers.len() <= 1,
 880            "expected at most one start-of-file marker, found {}",
 881            sof_markers.len()
 882        );
 883
 884        let codefence_start = start_markers[0].0;
 885        let content = &content[codefence_start..];
 886
 887        let newline_ix = content.find('\n').context("could not find newline")?;
 888        let content = &content[newline_ix + 1..];
 889
 890        let codefence_end = content
 891            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 892            .context("could not find end marker")?;
 893        let new_text = &content[..codefence_end];
 894
 895        let old_text = snapshot
 896            .text_for_range(editable_range.clone())
 897            .collect::<String>();
 898
 899        Ok(Self::compute_edits(
 900            old_text,
 901            new_text,
 902            editable_range.start,
 903            snapshot,
 904        ))
 905    }
 906
 907    pub fn compute_edits(
 908        old_text: String,
 909        new_text: &str,
 910        offset: usize,
 911        snapshot: &BufferSnapshot,
 912    ) -> Vec<(Range<Anchor>, Arc<str>)> {
 913        text_diff(&old_text, new_text)
 914            .into_iter()
 915            .map(|(mut old_range, new_text)| {
 916                old_range.start += offset;
 917                old_range.end += offset;
 918
 919                let prefix_len = common_prefix(
 920                    snapshot.chars_for_range(old_range.clone()),
 921                    new_text.chars(),
 922                );
 923                old_range.start += prefix_len;
 924
 925                let suffix_len = common_prefix(
 926                    snapshot.reversed_chars_for_range(old_range.clone()),
 927                    new_text[prefix_len..].chars().rev(),
 928                );
 929                old_range.end = old_range.end.saturating_sub(suffix_len);
 930
 931                let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
 932                let range = if old_range.is_empty() {
 933                    let anchor = snapshot.anchor_after(old_range.start);
 934                    anchor..anchor
 935                } else {
 936                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 937                };
 938                (range, new_text)
 939            })
 940            .collect()
 941    }
 942
    /// Returns `true` when `completion_id` is in the set of rated completions.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
 946
 947    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
 948        self.shown_completions.push_front(completion.clone());
 949        if self.shown_completions.len() > 50 {
 950            let completion = self.shown_completions.pop_back().unwrap();
 951            self.rated_completions.remove(&completion.id);
 952        }
 953        cx.notify();
 954    }
 955
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// The "Edit Prediction Rated" event carries the rating, the prediction's input events,
    /// input excerpt, input outline, output excerpt, and the free-form `feedback`. A telemetry
    /// flush is started (and detached) right away, then observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
 976
    /// Iterates over the stored shown predictions, front to back; `completion_shown` pushes
    /// new entries at the front, so the most recently shown prediction comes first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
 980
    /// Returns how many shown predictions are currently stored.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
 984
 985    fn report_changes_for_buffer(
 986        &mut self,
 987        buffer: &Entity<Buffer>,
 988        project: &Entity<Project>,
 989        cx: &mut Context<Self>,
 990    ) -> BufferSnapshot {
 991        let zeta_project = self.get_or_init_zeta_project(project, cx);
 992        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 993
 994        let new_snapshot = buffer.read(cx).snapshot();
 995        if new_snapshot.version != registered_buffer.snapshot.version {
 996            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 997            Self::push_event(
 998                zeta_project,
 999                Event::BufferChange {
1000                    old_snapshot,
1001                    new_snapshot: new_snapshot.clone(),
1002                    timestamp: Instant::now(),
1003                },
1004            );
1005        }
1006
1007        new_snapshot
1008    }
1009
1010    fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1011        self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
1012    }
1013
1014    fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
1015        if !self.data_collection_choice.is_enabled() {
1016            return false;
1017        }
1018        let mut last_checked_file = None;
1019        for event in events {
1020            match event {
1021                Event::BufferChange {
1022                    old_snapshot,
1023                    new_snapshot,
1024                    ..
1025                } => {
1026                    if let Some(old_file) = old_snapshot.file()
1027                        && let Some(new_file) = new_snapshot.file()
1028                    {
1029                        if let Some(last_checked_file) = last_checked_file
1030                            && Arc::ptr_eq(last_checked_file, old_file)
1031                            && Arc::ptr_eq(last_checked_file, new_file)
1032                        {
1033                            continue;
1034                        }
1035                        if !self.can_collect_file(old_file, cx) {
1036                            return false;
1037                        }
1038                        if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
1039                        {
1040                            return false;
1041                        }
1042                        last_checked_file = Some(new_file);
1043                    } else {
1044                        return false;
1045                    }
1046                }
1047            }
1048        }
1049        true
1050    }
1051
1052    fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1053        if !file.is_local() || file.is_private() {
1054            return false;
1055        }
1056        self.license_detection_watchers
1057            .get(&file.worktree_id(cx))
1058            .is_some_and(|watcher| watcher.is_project_open_source())
1059    }
1060
1061    fn load_data_collection_choice() -> DataCollectionChoice {
1062        let choice = KEY_VALUE_STORE
1063            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1064            .log_err()
1065            .flatten();
1066
1067        match choice.as_deref() {
1068            Some("true") => DataCollectionChoice::Enabled,
1069            Some("false") => DataCollectionChoice::Disabled,
1070            Some(_) => {
1071                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1072                DataCollectionChoice::NotAnswered
1073            }
1074            None => DataCollectionChoice::NotAnswered,
1075        }
1076    }
1077
1078    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1079        self.data_collection_choice = self.data_collection_choice.toggle();
1080        let new_choice = self.data_collection_choice;
1081        db::write_and_log(cx, move || {
1082            KEY_VALUE_STORE.write_kvp(
1083                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1084                new_choice.is_enabled().to_string(),
1085            )
1086        });
1087    }
1088
1089    fn discard_completion(
1090        &mut self,
1091        completion_id: EditPredictionId,
1092        was_shown: bool,
1093        cx: &mut Context<Self>,
1094    ) {
1095        self.discarded_completions.push(EditPredictionRejection {
1096            request_id: completion_id.to_string(),
1097            was_shown,
1098        });
1099
1100        let reached_request_limit =
1101            self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
1102        let discard_completions_tx = self.discard_completions_tx.clone();
1103        self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
1104            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
1105            if !reached_request_limit {
1106                cx.background_executor()
1107                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
1108                    .await;
1109            }
1110            discard_completions_tx.unbounded_send(()).log_err();
1111        }));
1112    }
1113}
1114
/// Inputs bundled for a predict-edits request.
///
/// NOTE(review): the consumer of this struct is outside this section; the field notes below
/// are inferred from the field types and names and should be confirmed there.
pub struct PerformPredictEditsParams {
    /// Client handle (presumably used to send the request).
    pub client: Arc<Client>,
    /// LLM API token handle (presumably used to authorize the request).
    pub llm_token: LlmApiToken,
    /// Version of the running application.
    pub app_version: SemanticVersion,
    /// Request body to send.
    pub body: PredictEditsBody,
}
1121
/// Error reported when the server's minimum-required-version header names a version newer
/// than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version required to keep using edit predictions.
    minimum_version: SemanticVersion,
}
1129
/// Returns the length in UTF-8 bytes of the longest common prefix of the two character
/// sequences.
///
/// Comparison stops at the first differing pair or when either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1136
1137fn git_info_for_file(
1138    project: &Entity<Project>,
1139    project_path: &ProjectPath,
1140    cx: &App,
1141) -> Option<PredictEditsGitInfo> {
1142    let git_store = project.read(cx).git_store().read(cx);
1143    if let Some((repository, _repo_path)) =
1144        git_store.repository_and_path_for_project_path(project_path, cx)
1145    {
1146        let repository = repository.read(cx);
1147        let head_sha = repository
1148            .head_commit
1149            .as_ref()
1150            .map(|head_commit| head_commit.sha.to_string());
1151        let remote_origin_url = repository.remote_origin_url.clone();
1152        let remote_upstream_url = repository.remote_upstream_url.clone();
1153        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1154            return None;
1155        }
1156        Some(PredictEditsGitInfo {
1157            head_sha,
1158            remote_origin_url,
1159            remote_upstream_url,
1160        })
1161    } else {
1162        None
1163    }
1164}
1165
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// Request body built from the input excerpt and the event prompt.
    pub body: PredictEditsBody,
    /// Editable range of the input excerpt, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
    /// Second element of the tuple returned by the `prompt_for_events` callback.
    pub included_events_count: usize,
}
1171
1172pub fn gather_context(
1173    full_path_str: String,
1174    snapshot: &BufferSnapshot,
1175    cursor_point: language::Point,
1176    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1177    cx: &App,
1178) -> Task<Result<GatherContextOutput>> {
1179    cx.background_spawn({
1180        let snapshot = snapshot.clone();
1181        async move {
1182            let input_excerpt = excerpt_for_cursor_position(
1183                cursor_point,
1184                &full_path_str,
1185                &snapshot,
1186                MAX_REWRITE_TOKENS,
1187                MAX_CONTEXT_TOKENS,
1188            );
1189            let (input_events, included_events_count) = prompt_for_events();
1190            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1191
1192            let body = PredictEditsBody {
1193                input_events,
1194                input_excerpt: input_excerpt.prompt,
1195                can_collect_data: false,
1196                diagnostic_groups: None,
1197                git_info: None,
1198                outline: None,
1199                speculated_output: None,
1200            };
1201
1202            Ok(GatherContextOutput {
1203                body,
1204                editable_range,
1205                included_events_count,
1206            })
1207        }
1208    })
1209}
1210
1211fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1212    let mut result = String::new();
1213    for (ix, event) in events.iter().rev().enumerate() {
1214        let event_string = event.to_prompt();
1215        let event_tokens = guess_token_count(event_string.len());
1216        if event_tokens > remaining_tokens {
1217            return (result, ix);
1218        }
1219
1220        if !result.is_empty() {
1221            result.insert_str(0, "\n\n");
1222        }
1223        result.insert_str(0, &event_string);
1224        remaining_tokens -= event_tokens;
1225    }
1226    return (result, events.len());
1227}
1228
/// Per-buffer state tracked for change reporting.
struct RegisteredBuffer {
    /// Snapshot last recorded for the buffer; `report_changes_for_buffer` compares its
    /// version against the current one to detect edits.
    snapshot: BufferSnapshot,
    /// Subscriptions stored alongside the snapshot (presumably kept only so they stay
    /// alive; confirm where the buffer is registered).
    _subscriptions: [gpui::Subscription; 2],
}
1233
/// An event that can be rendered into the prediction prompt (see `Event::to_prompt`).
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1242
1243impl Event {
1244    fn to_prompt(&self) -> String {
1245        match self {
1246            Event::BufferChange {
1247                old_snapshot,
1248                new_snapshot,
1249                ..
1250            } => {
1251                let mut prompt = String::new();
1252
1253                let old_path = old_snapshot
1254                    .file()
1255                    .map(|f| f.path().as_ref())
1256                    .unwrap_or(RelPath::unix("untitled").unwrap());
1257                let new_path = new_snapshot
1258                    .file()
1259                    .map(|f| f.path().as_ref())
1260                    .unwrap_or(RelPath::unix("untitled").unwrap());
1261                if old_path != new_path {
1262                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1263                }
1264
1265                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1266                if !diff.is_empty() {
1267                    write!(
1268                        prompt,
1269                        "User edited {:?}:\n```diff\n{}\n```",
1270                        new_path, diff
1271                    )
1272                    .unwrap();
1273                }
1274
1275                prompt
1276            }
1277        }
1278    }
1279}
1280
/// The prediction a provider currently holds, plus bookkeeping about what happened to it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
    /// Set once the provider's `did_show` has been called for this prediction.
    was_shown: bool,
    /// Set once the provider's `accept` has been called for this prediction.
    was_accepted: bool,
}
1288
1289impl CurrentEditPrediction {
1290    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1291        if self.buffer_id != old_completion.buffer_id {
1292            return true;
1293        }
1294
1295        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1296            return true;
1297        };
1298        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1299            return false;
1300        };
1301
1302        if old_edits.len() == 1 && new_edits.len() == 1 {
1303            let (old_range, old_text) = &old_edits[0];
1304            let (new_range, new_text) = &new_edits[0];
1305            new_range == old_range && new_text.starts_with(old_text.as_ref())
1306        } else {
1307            true
1308        }
1309    }
1310}
1311
/// A prediction request that has been spawned but has not finished yet.
struct PendingCompletion {
    /// Identifier taken from the provider's `next_pending_completion_id` counter.
    id: usize,
    /// The spawned request task.
    task: Task<()>,
}
1316
/// The user's answer to the data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` for every variant except [`DataCollectionChoice::NotAnswered`].
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1357
1358async fn llm_token_retry(
1359    llm_token: &LlmApiToken,
1360    client: &Arc<Client>,
1361    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1362) -> Result<Response<AsyncBody>> {
1363    let mut did_retry = false;
1364    let http_client = client.http_client();
1365    let mut token = llm_token.acquire(client).await?;
1366    loop {
1367        let request = build_request(token.clone())?;
1368        let response = http_client.send(request).await?;
1369
1370        if !did_retry
1371            && !response.status().is_success()
1372            && response
1373                .headers()
1374                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1375                .is_some()
1376        {
1377            did_retry = true;
1378            token = llm_token.refresh(client).await?;
1379            continue;
1380        }
1381
1382        return Ok(response);
1383    }
1384}
1385
/// Edit prediction provider backed by a [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    /// Shared Zeta state that performs requests and records shown/discarded predictions.
    zeta: Entity<Zeta>,
    /// Buffer consulted by `data_collection_state`, when one was supplied at construction.
    singleton_buffer: Option<Entity<Buffer>>,
    /// Requests currently in flight; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Tasks of requests that were evicted from `pending_completions`, keyed by request id;
    /// each entry is removed when its own task reaches the bookkeeping step in `refresh`.
    canceled_completions: HashMap<usize, Task<()>>,
    /// Id handed to the next request spawned by `refresh`.
    next_pending_completion_id: usize,
    /// The prediction currently held, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was issued; used with `THROTTLE_TIMEOUT` to space requests.
    last_request_timestamp: Instant,
    /// Project passed to Zeta when requesting completions.
    project: Entity<Project>,
}
1396
1397impl ZetaEditPredictionProvider {
1398    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1399
1400    pub fn new(
1401        zeta: Entity<Zeta>,
1402        project: Entity<Project>,
1403        singleton_buffer: Option<Entity<Buffer>>,
1404        cx: &mut Context<Self>,
1405    ) -> Self {
1406        cx.on_release(|this, cx| {
1407            this.take_current_edit_prediction(cx);
1408        })
1409        .detach();
1410
1411        Self {
1412            zeta,
1413            singleton_buffer,
1414            pending_completions: ArrayVec::new(),
1415            canceled_completions: HashMap::default(),
1416            next_pending_completion_id: 0,
1417            current_completion: None,
1418            last_request_timestamp: Instant::now(),
1419            project,
1420        }
1421    }
1422
1423    fn take_current_edit_prediction(&mut self, cx: &mut App) {
1424        if let Some(completion) = self.current_completion.take() {
1425            if !completion.was_accepted {
1426                self.zeta.update(cx, |zeta, cx| {
1427                    zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
1428                });
1429            }
1430        }
1431    }
1432}
1433
1434impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Identifier of this provider: `"zed-predict"`.
    fn name() -> &'static str {
        "zed-predict"
    }
1438
    /// Human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }
1442
    /// Always `true` for this provider (the flag's exact meaning is defined by the
    /// `EditPredictionProvider` trait, outside this section).
    fn show_completions_in_menu() -> bool {
        true
    }
1446
    /// Always `true` for this provider (the flag's exact meaning is defined by the
    /// `EditPredictionProvider` trait, outside this section).
    fn show_tab_accept_marker() -> bool {
        true
    }
1450
1451    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1452        if let Some(buffer) = &self.singleton_buffer
1453            && let Some(file) = buffer.read(cx).file()
1454        {
1455            let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
1456            if self.zeta.read(cx).data_collection_choice.is_enabled() {
1457                DataCollectionState::Enabled {
1458                    is_project_open_source,
1459                }
1460            } else {
1461                DataCollectionState::Disabled {
1462                    is_project_open_source,
1463                }
1464            }
1465        } else {
1466            return DataCollectionState::Disabled {
1467                is_project_open_source: false,
1468            };
1469        }
1470    }
1471
    /// Toggles the data-collection choice stored on the shared Zeta entity.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.zeta
            .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
    }
1476
    /// Returns the edit prediction usage reported by Zeta, if any.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
1480
    /// Always `true`: this provider is enabled regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// Returns `true` while at least one prediction request is pending.
    fn is_refreshing(&self, _cx: &App) -> bool {
        !self.pending_completions.is_empty()
    }
1492
    /// Requests a new edit prediction for `buffer` at `position`.
    ///
    /// Nothing is requested when Zeta reports that an update is required, when the user's
    /// account is too young or has overdue invoices, or when the current prediction still
    /// interpolates onto the buffer's latest snapshot. Otherwise a task is spawned that waits
    /// out whatever remains of `THROTTLE_TIMEOUT` since the previous request, asks Zeta for a
    /// completion, and installs it as the current prediction when there is none or when
    /// `should_replace_completion` says so.
    ///
    /// `_debounce` is ignored.
    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if self.zeta.read(cx).update_required {
            return;
        }

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction as long as it still applies to the buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let last_request_timestamp = self.last_request_timestamp;

        let project = self.project.clone();
        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait until `THROTTLE_TIMEOUT` has passed since the last request.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(&project, &buffer, position, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                            was_shown: false,
                            was_accepted: false,
                        })
                    })
                }
                Err(error) => Err(error),
            };

            // Bookkeeping for this request: take it out of the pending list and find out
            // whether it was canceled in the meantime. `discarded` also ends up `true` when
            // the provider can no longer be updated.
            let discarded = this
                .update(cx, |this, cx| {
                    if this
                        .pending_completions
                        .first()
                        .is_some_and(|completion| completion.id == pending_completion_id)
                    {
                        this.pending_completions.remove(0);
                    } else {
                        // This request is not the oldest pending one: empty the pending
                        // list, keeping the task of its first entry in the canceled map.
                        if let Some(discarded) = this.pending_completions.drain(..).next() {
                            this.canceled_completions
                                .insert(discarded.id, discarded.task);
                        }
                    }

                    let canceled = this.canceled_completions.remove(&pending_completion_id);

                    // A canceled request that still produced a prediction is reported to
                    // Zeta as discarded without ever having been shown.
                    if canceled.is_some()
                        && let Ok(Some(new_completion)) = &completion
                    {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.discard_completion(new_completion.completion.id, false, cx);
                        });
                        return true;
                    }

                    cx.notify();
                    false
                })
                .ok()
                .unwrap_or(true);

            if discarded {
                return;
            }

            // Request errors are logged here; both an error and an empty result end the task.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                return;
            };

            this.update(cx, |this, cx| {
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.take_current_edit_prediction(cx);
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                task,
            });
        } else if self.pending_completions.len() == 2 {
            // The replaced request's task is parked in `canceled_completions` rather than
            // dropped here.
            if let Some(discarded) = self.pending_completions.pop() {
                self.canceled_completions
                    .insert(discarded.id, discarded.task);
            }
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                task,
            });
        }
    }
1644
    /// No-op: cycling between alternative predictions is not supported by this provider.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }
1654
1655    fn accept(&mut self, cx: &mut Context<Self>) {
1656        let completion = self.current_completion.as_mut();
1657        if let Some(completion) = completion {
1658            completion.was_accepted = true;
1659            self.zeta
1660                .update(cx, |zeta, cx| {
1661                    zeta.accept_edit_prediction(completion.completion.id, cx)
1662                })
1663                .detach();
1664        }
1665        self.pending_completions.clear();
1666    }
1667
1668    fn discard(&mut self, cx: &mut Context<Self>) {
1669        self.pending_completions.clear();
1670        self.take_current_edit_prediction(cx);
1671    }
1672
1673    fn did_show(&mut self, _cx: &mut Context<Self>) {
1674        if let Some(current_completion) = self.current_completion.as_mut() {
1675            current_completion.was_shown = true;
1676        }
1677    }
1678
1679    fn suggest(
1680        &mut self,
1681        buffer: &Entity<Buffer>,
1682        cursor_position: language::Anchor,
1683        cx: &mut Context<Self>,
1684    ) -> Option<edit_prediction::EditPrediction> {
1685        let CurrentEditPrediction {
1686            buffer_id,
1687            completion,
1688            ..
1689        } = self.current_completion.as_mut()?;
1690
1691        // Invalidate previous completion if it was generated for a different buffer.
1692        if *buffer_id != buffer.entity_id() {
1693            self.take_current_edit_prediction(cx);
1694            return None;
1695        }
1696
1697        let buffer = buffer.read(cx);
1698        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1699            self.take_current_edit_prediction(cx);
1700            return None;
1701        };
1702
1703        let cursor_row = cursor_position.to_point(buffer).row;
1704        let (closest_edit_ix, (closest_edit_range, _)) =
1705            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1706                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1707                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1708                cmp::min(distance_from_start, distance_from_end)
1709            })?;
1710
1711        let mut edit_start_ix = closest_edit_ix;
1712        for (range, _) in edits[..edit_start_ix].iter().rev() {
1713            let distance_from_closest_edit =
1714                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1715            if distance_from_closest_edit <= 1 {
1716                edit_start_ix -= 1;
1717            } else {
1718                break;
1719            }
1720        }
1721
1722        let mut edit_end_ix = closest_edit_ix + 1;
1723        for (range, _) in &edits[edit_end_ix..] {
1724            let distance_from_closest_edit =
1725                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1726            if distance_from_closest_edit <= 1 {
1727                edit_end_ix += 1;
1728            } else {
1729                break;
1730            }
1731        }
1732
1733        Some(edit_prediction::EditPrediction::Local {
1734            id: Some(completion.id.to_string().into()),
1735            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1736            edit_preview: Some(completion.edit_preview.clone()),
1737        })
1738    }
1739}
1740
/// Assumed number of string bytes per model token, used when limiting model
/// input. The value is deliberately on the low side so that limits end up
/// underestimated rather than overestimated.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text correspond to, using
/// integer division by [`BYTES_PER_TOKEN_GUESS`] (the result rounds down).
fn guess_token_count(bytes: usize) -> usize {
    let estimated_tokens = bytes / BYTES_PER_TOKEN_GUESS;
    estimated_tokens
}
1748
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::{FakeSystemClock, ReplicaId};
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use parking_lot::Mutex;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    use super::*;

    /// License text inserted as a worktree's `LICENSE` file by the data
    /// collection tests below.
    const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");

    /// Checks how `EditPrediction::interpolate` re-maps predicted edits while
    /// the buffer is edited after the prediction was produced.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the edits come back unchanged.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the text the first edit replaces turns it into an insertion
            // and shifts the second edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing the deletion restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing the first predicted character leaves only the rest to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            // Typing the second predicted character.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Typing the last predicted character: only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            // Removing the character that was just typed brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand leaves only the insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            // After this deletion the prediction can no longer be interpolated.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Applies predictions wrapped in editable-region markers and checks the
    /// resulting buffer text.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test(cx);

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
        );

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
        );
    }

    /// A prediction that appends a line after the last line of the buffer.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        assert_eq!(
            apply_edit_prediction(buffer_content, completion_response, cx).await,
            "lorem\nipsum"
        );
    }

    /// In a worktree that contains a `LICENSE` file with the bundled license
    /// text, `can_collect_data` follows the `data_collection_choice` setting.
    #[gpui::test]
    async fn test_can_collect_data(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/src/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(captured_can_collect_data(&captured_request));

        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Disabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));
    }

    /// A buffer created with `Buffer::remote` never allows data collection.
    #[gpui::test]
    async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let buffer = cx.new(|_cx| {
            Buffer::remote(
                language::BufferId::new(1).unwrap(),
                ReplicaId::new(1),
                language::Capability::ReadWrite,
                "fn main() {\n    println!(\"Hello\");\n}",
            )
        });

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));
    }

    /// A `.env` file does not allow data collection even though its worktree
    /// contains a `LICENSE` file (presumably because `.env` counts as a private
    /// file under the default settings — confirm against the settings defaults).
    #[gpui::test]
    async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/project"),
            json!({
                "LICENSE": BSD_0_TXT,
                ".env": "SECRET_KEY=secret"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                // Use `path!` like the other tests so the path agrees with the
                // worktree root registered above.
                project.open_local_buffer(path!("/project/.env"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));
    }

    /// A local buffer that is not backed by any file does not allow data
    /// collection.
    #[gpui::test]
    async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;
        let buffer = cx.new(|cx| Buffer::local("", cx));

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));
    }

    /// A worktree without a `LICENSE` file does not allow data collection.
    #[gpui::test]
    async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                // Use `path!` like the other tests so the path agrees with the
                // worktree root registered above.
                project.open_local_buffer(path!("/project/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));
    }

    /// Re-pointing a buffer from a file in the licensed worktree to a file in
    /// the unlicensed worktree flips `can_collect_data` from true to false.
    #[gpui::test]
    async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/open_source_worktree"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [
                path!("/open_source_worktree").as_ref(),
                path!("/closed_source_worktree").as_ref(),
            ],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(captured_can_collect_data(&captured_request));

        let closed_source_file = project
            .update(cx, |project, cx| {
                let worktree2 = project
                    .worktree_for_root_name("closed_source_worktree", cx)
                    .unwrap();
                worktree2.update(cx, |worktree2, cx| {
                    worktree2.load_file(rel_path("main.rs"), cx)
                })
            })
            .await
            .unwrap()
            .file;

        // Point the existing buffer at the file from the unlicensed worktree.
        buffer.update(cx, |buffer, cx| {
            buffer.file_updated(closed_source_file, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));
    }

    /// Edits made in a buffer that does not allow data collection disable
    /// collection for requests from other buffers too, for as long as those
    /// edits are part of the request's event history.
    #[gpui::test]
    async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/worktree1"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/worktree1/main.rs"), cx)
            })
            .await
            .unwrap();
        // NOTE(review): worktree2 is seeded with `private.rs`, yet the buffer is
        // opened at `file.rs` — confirm this mismatch is intentional.
        let private_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/worktree2/file.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(captured_can_collect_data(&captured_request));

        // this has a side effect of registering the buffer to watch for edits
        run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));

        private_buffer.update(cx, |private_buffer, cx| {
            private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!captured_can_collect_data(&captured_request));

        // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
        // included
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
                None,
                cx,
            );
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(captured_can_collect_data(&captured_request));
    }

    /// Installs a test `SettingsStore` as a global.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Returns the `can_collect_data` flag of the most recently captured
    /// `/predict_edits/v2` request body.
    ///
    /// Panics if no request has been captured yet.
    fn captured_can_collect_data(captured_request: &Mutex<Option<PredictEditsBody>>) -> bool {
        captured_request
            .lock()
            .as_ref()
            .expect("no predict_edits request was captured")
            .can_collect_data
    }

    /// Creates a local buffer holding `buffer_content`, requests a prediction
    /// whose server response is `completion_response`, applies the predicted
    /// edits to the buffer and returns the resulting text.
    async fn apply_edit_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> String {
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let (zeta, _, response) = make_test_zeta(&project, cx).await;
        *response.lock() = completion_response.to_string();
        let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
        buffer.update(cx, |buffer, cx| {
            buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
        });
        buffer.read_with(cx, |buffer, _| buffer.text())
    }

    /// Registers `buffer` with `zeta`, requests a completion with the cursor
    /// anchored at row 1, column 0, and returns the resulting prediction.
    ///
    /// Panics if the request fails or yields no prediction.
    async fn run_edit_prediction(
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        zeta: &Entity<Zeta>,
        cx: &mut TestAppContext,
    ) -> EditPrediction {
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, project, cx));
        cx.background_executor.run_until_parked();
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(project, buffer, cursor, cx)
        });
        completion_task.await.unwrap().unwrap()
    }

    /// Builds a `Zeta` entity backed by a fake HTTP client.
    ///
    /// Returns the entity together with two shared handles: the body of the
    /// last `/predict_edits/v2` request that was received, and the string sent
    /// back as `output_excerpt` (tests may overwrite it before requesting).
    async fn make_test_zeta(
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        Arc<Mutex<Option<PredictEditsBody>>>,
        Arc<Mutex<String>>,
    ) {
        let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
        };
        let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
        let completion_response: Arc<Mutex<String>> =
            Arc::new(Mutex::new(default_response.to_string()));
        let http_client = FakeHttpClient::create({
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            move |req| {
                let captured_request = captured_request.clone();
                let completion_response = completion_response.clone();
                async move {
                    match (req.method(), req.uri().path()) {
                        // Token endpoint: always hands out the same fake token.
                        (&Method::POST, "/client/llm_tokens") => {
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&CreateLlmTokenResponse {
                                        token: LlmToken("the-llm-token".to_string()),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        // Prediction endpoint: remember the request body and reply
                        // with whatever response the test has configured.
                        (&Method::POST, "/predict_edits/v2") => {
                            let mut request_body = String::new();
                            req.into_body().read_to_string(&mut request_body).await?;
                            *captured_request.lock() =
                                Some(serde_json::from_str(&request_body).unwrap());
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&PredictEditsResponse {
                                        request_id: Uuid::new_v4().to_string(),
                                        output_excerpt: completion_response.lock().clone(),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        _ => Ok(http_client::Response::builder()
                            .status(404)
                            .body("Not Found".into())
                            .unwrap()),
                    }
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let _server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new(|cx| {
            let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);

            // Give every worktree of the project a license detection watcher.
            let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
            for worktree in worktrees {
                let worktree_id = worktree.read(cx).id();
                zeta.license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
            }

            zeta
        });

        (zeta, captured_request, completion_response)
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`; each
    /// range starts at `anchor_after(start)` and ends at `anchor_before(end)`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset-based edits, resolving the
    /// anchors against the current contents of `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Initializes test logging through a `ctor` constructor function.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}