zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::{Dismissable, KEY_VALUE_STORE};
  11use db::smol::stream::StreamExt as _;
  12use edit_prediction::DataCollectionState;
  13use futures::channel::mpsc;
  14pub use init::*;
  15use license_detection::LicenseDetectionWatcher;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use arrayvec::ArrayVec;
  20use client::{Client, EditPredictionUsage, UserStore};
  21use cloud_llm_client::{
  22    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
  23    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  24    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
  25    ZED_VERSION_HEADER_NAME,
  26};
  27use collections::{HashMap, HashSet, VecDeque};
  28use futures::AsyncReadExt;
  29use gpui::{
  30    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SharedString, Subscription,
  31    Task, actions,
  32};
  33use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  34use input_excerpt::excerpt_for_cursor_position;
  35use language::{
  36    Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  37};
  38use language_model::{LlmApiToken, RefreshLlmTokenListener};
  39use project::{Project, ProjectPath};
  40use release_channel::AppVersion;
  41use semver::Version;
  42use settings::WorktreeId;
  43use std::collections::hash_map;
  44use std::mem;
  45use std::str::FromStr;
  46use std::{
  47    cmp,
  48    fmt::Write,
  49    future::Future,
  50    ops::Range,
  51    path::Path,
  52    rc::Rc,
  53    sync::Arc,
  54    time::{Duration, Instant},
  55};
  56use telemetry_events::EditPredictionRating;
  57use thiserror::Error;
  58use util::ResultExt;
  59use util::rel_path::RelPath;
  60use uuid::Uuid;
  61use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  62use worktree::Worktree;
  63
// Special tokens of the prompt format. They are not referenced in this part of the file;
// presumably the excerpt/response handling uses them — NOTE(review): confirm.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive edits that continue the same buffer change and arrive within this interval
/// are coalesced into a single `Event::BufferChange` by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the user's data collection choice. Its presence also marks
/// the edit prediction upsell as dismissed (see `ZedPredictUpsell::dismissed`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets for the prompt. `MAX_EVENT_TOKENS` is the budget handed to
// `prompt_for_events_impl` for the rendered event history. The other two are not used in
// this part of the file; presumably they size the context / editable excerpt — TODO confirm.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track per project. When the limit is reached, the oldest
/// half of the history is dropped at once rather than one event at a time.
const MAX_EVENT_COUNT: usize = 16;
  77
// Declares the `ClearHistory` action in the `edit_prediction` namespace. Its handler is not
// in this part of the file; presumably it ends up calling `Zeta::clear_history` — TODO confirm.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  85
/// Identifier of a single edit prediction, wrapping a [`Uuid`].
///
/// In this file it is parsed from the server response's `request_id` string, and its string
/// form is sent back as the `request_id` when a prediction is accepted.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
  88
  89impl From<EditPredictionId> for gpui::ElementId {
  90    fn from(value: EditPredictionId) -> Self {
  91        gpui::ElementId::Uuid(value.0)
  92    }
  93}
  94
  95impl std::fmt::Display for EditPredictionId {
  96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  97        write!(f, "{}", self.0)
  98    }
  99}
 100
 101struct ZedPredictUpsell;
 102
 103impl Dismissable for ZedPredictUpsell {
 104    const KEY: &'static str = "dismissed-edit-predict-upsell";
 105
 106    fn dismissed() -> bool {
 107        // To make this backwards compatible with older versions of Zed, we
 108        // check if the user has seen the previous Edit Prediction Onboarding
 109        // before, by checking the data collection choice which was written to
 110        // the database once the user clicked on "Accept and Enable"
 111        if KEY_VALUE_STORE
 112            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 113            .log_err()
 114            .is_some_and(|s| s.is_some())
 115        {
 116            return true;
 117        }
 118
 119        KEY_VALUE_STORE
 120            .read_kvp(Self::KEY)
 121            .log_err()
 122            .is_some_and(|s| s.is_some())
 123    }
 124}
 125
 126pub fn should_show_upsell_modal() -> bool {
 127    !ZedPredictUpsell::dismissed()
 128}
 129
/// App-wide handle to the single `Zeta` entity. `Zeta::register` installs it as a GPUI
/// global, and `Zeta::global` reads it back.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 134
/// An edit prediction for a single buffer, together with the request inputs and timing
/// information that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction (presumably the server's request id — TODO confirm).
    id: EditPredictionId,
    // Presumably the buffer file's full path (`untitled` for buffers without a file), the
    // byte range of the excerpt sent to the model, and the cursor's byte offset at request
    // time — TODO confirm against `process_completion_response`.
    path: Arc<Path>,
    excerpt_range: Range<usize>,
    cursor_offset: usize,
    /// Predicted replacements, anchored in `snapshot`; see `EditPrediction::interpolate`.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Buffer snapshot that `edits` are expressed against.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
    // Prompt inputs sent to the server and the raw model output. Presumably kept for
    // rating / telemetry — TODO confirm.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request (start of `latency`).
    buffer_snapshotted_at: Instant,
    /// When the server response arrived (end of `latency`).
    response_received_at: Instant,
}
 151
 152impl EditPrediction {
 153    fn latency(&self) -> Duration {
 154        self.response_received_at
 155            .duration_since(self.buffer_snapshotted_at)
 156    }
 157
 158    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 159        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 160    }
 161}
 162
 163impl std::fmt::Debug for EditPrediction {
 164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 165        f.debug_struct("EditPrediction")
 166            .field("id", &self.id)
 167            .field("path", &self.path)
 168            .field("edits", &self.edits)
 169            .finish_non_exhaustive()
 170    }
 171}
 172
/// Edit prediction state shared across projects: per-project event history, the LLM API
/// token, pending prediction rejections and per-worktree license detection watchers.
pub struct Zeta {
    /// Per-project state keyed by the project's entity id. An entry is removed when its
    /// project entity is released.
    projects: HashMap<EntityId, ZetaProject>,
    /// Client used to acquire/refresh the LLM token and to send HTTP requests.
    client: Arc<Client>,
    // Not used in this part of the file. Presumably the predictions shown to the user and
    // the ids that have already been rated — TODO confirm.
    shown_completions: VecDeque<EditPrediction>,
    rated_completions: HashSet<EditPredictionId>,
    /// Data collection choice, computed by `load_data_collection_choice` at construction.
    data_collection_choice: DataCollectionChoice,
    /// Rejections not yet reported to the server. `reject_edit_predictions` sends them and
    /// removes the sent ones after a successful response.
    discarded_completions: Vec<EditPredictionRejection>,
    /// Bearer token for the Zed LLM endpoints. Refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    /// One license detection watcher per worktree passed to `Zeta::register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Not used in this part of the file. Presumably the pending debounce before a
    // rejection flush is triggered — TODO confirm.
    discard_completions_debounce_task: Option<Task<()>>,
    /// Each message sent here makes the loop spawned in `Zeta::new` call
    /// `reject_edit_predictions`.
    discard_completions_tx: mpsc::UnboundedSender<()>,
}
 189
/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    /// Recent buffer change events, oldest first. Bounded by `MAX_EVENT_COUNT` in
    /// `Zeta::push_event`.
    events: VecDeque<Event>,
    /// Buffers whose edits are observed, keyed by buffer entity id. An entry is removed
    /// when its buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
 194
 195impl Zeta {
 196    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 197        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 198    }
 199
 200    pub fn register(
 201        worktree: Option<Entity<Worktree>>,
 202        client: Arc<Client>,
 203        user_store: Entity<UserStore>,
 204        cx: &mut App,
 205    ) -> Entity<Self> {
 206        let this = Self::global(cx).unwrap_or_else(|| {
 207            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 208            cx.set_global(ZetaGlobal(entity.clone()));
 209            entity
 210        });
 211
 212        this.update(cx, move |this, cx| {
 213            if let Some(worktree) = worktree {
 214                let worktree_id = worktree.read(cx).id();
 215                this.license_detection_watchers
 216                    .entry(worktree_id)
 217                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 218            }
 219        });
 220
 221        this
 222    }
 223
 224    pub fn clear_history(&mut self) {
 225        for zeta_project in self.projects.values_mut() {
 226            zeta_project.events.clear();
 227        }
 228    }
 229
 230    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 231        self.user_store.read(cx).edit_prediction_usage()
 232    }
 233
 234    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 235        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 236        let data_collection_choice = Self::load_data_collection_choice();
 237        let (reject_tx, mut reject_rx) = mpsc::unbounded();
 238        cx.spawn(async move |this, cx| {
 239            while let Some(()) = reject_rx.next().await {
 240                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
 241                    .await
 242                    .log_err();
 243            }
 244            anyhow::Ok(())
 245        })
 246        .detach();
 247
 248        Self {
 249            projects: HashMap::default(),
 250            client,
 251            shown_completions: VecDeque::new(),
 252            rated_completions: HashSet::default(),
 253            discarded_completions: Vec::new(),
 254            discard_completions_debounce_task: None,
 255            discard_completions_tx: reject_tx,
 256            data_collection_choice,
 257            llm_token: LlmApiToken::default(),
 258            _llm_token_subscription: cx.subscribe(
 259                &refresh_llm_token_listener,
 260                |this, _listener, _event, cx| {
 261                    let client = this.client.clone();
 262                    let llm_token = this.llm_token.clone();
 263                    cx.spawn(async move |_this, _cx| {
 264                        llm_token.refresh(&client).await?;
 265                        anyhow::Ok(())
 266                    })
 267                    .detach_and_log_err(cx);
 268                },
 269            ),
 270            update_required: false,
 271            license_detection_watchers: HashMap::default(),
 272            user_store,
 273        }
 274    }
 275
 276    fn get_or_init_zeta_project(
 277        &mut self,
 278        project: &Entity<Project>,
 279        cx: &mut Context<Self>,
 280    ) -> &mut ZetaProject {
 281        let project_id = project.entity_id();
 282        match self.projects.entry(project_id) {
 283            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 284            hash_map::Entry::Vacant(entry) => {
 285                cx.observe_release(project, move |this, _, _cx| {
 286                    this.projects.remove(&project_id);
 287                })
 288                .detach();
 289                entry.insert(ZetaProject {
 290                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 291                    registered_buffers: HashMap::default(),
 292                })
 293            }
 294        }
 295    }
 296
 297    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 298        let events = &mut zeta_project.events;
 299
 300        if let Some(Event::BufferChange {
 301            new_snapshot: last_new_snapshot,
 302            timestamp: last_timestamp,
 303            ..
 304        }) = events.back_mut()
 305        {
 306            // Coalesce edits for the same buffer when they happen one after the other.
 307            let Event::BufferChange {
 308                old_snapshot,
 309                new_snapshot,
 310                timestamp,
 311            } = &event;
 312
 313            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 314                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 315                && old_snapshot.version == last_new_snapshot.version
 316            {
 317                *last_new_snapshot = new_snapshot.clone();
 318                *last_timestamp = *timestamp;
 319                return;
 320            }
 321        }
 322
 323        if events.len() >= MAX_EVENT_COUNT {
 324            // These are halved instead of popping to improve prompt caching.
 325            events.drain(..MAX_EVENT_COUNT / 2);
 326        }
 327
 328        events.push_back(event);
 329    }
 330
 331    pub fn register_buffer(
 332        &mut self,
 333        buffer: &Entity<Buffer>,
 334        project: &Entity<Project>,
 335        cx: &mut Context<Self>,
 336    ) {
 337        let zeta_project = self.get_or_init_zeta_project(project, cx);
 338        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 339    }
 340
 341    fn register_buffer_impl<'a>(
 342        zeta_project: &'a mut ZetaProject,
 343        buffer: &Entity<Buffer>,
 344        project: &Entity<Project>,
 345        cx: &mut Context<Self>,
 346    ) -> &'a mut RegisteredBuffer {
 347        let buffer_id = buffer.entity_id();
 348        match zeta_project.registered_buffers.entry(buffer_id) {
 349            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 350            hash_map::Entry::Vacant(entry) => {
 351                let snapshot = buffer.read(cx).snapshot();
 352                let project_entity_id = project.entity_id();
 353                entry.insert(RegisteredBuffer {
 354                    snapshot,
 355                    _subscriptions: [
 356                        cx.subscribe(buffer, {
 357                            let project = project.downgrade();
 358                            move |this, buffer, event, cx| {
 359                                if let language::BufferEvent::Edited = event
 360                                    && let Some(project) = project.upgrade()
 361                                {
 362                                    this.report_changes_for_buffer(&buffer, &project, cx);
 363                                }
 364                            }
 365                        }),
 366                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 367                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 368                            else {
 369                                return;
 370                            };
 371                            zeta_project.registered_buffers.remove(&buffer_id);
 372                        }),
 373                    ],
 374                })
 375            }
 376        }
 377    }
 378
    /// Requests an edit prediction for `buffer` at `cursor`.
    ///
    /// Synchronously, in order:
    /// - report pending changes for `buffer`, which yields the snapshot used for the request;
    /// - copy the project's event history;
    /// - decide whether the file may be collected, and only then look up git info;
    /// - start gathering the prompt context.
    ///
    /// Then, on a spawned task:
    /// - decide whether the request may be used for data collection;
    /// - call `perform_predict_edits`;
    /// - store any returned usage in the user store;
    /// - turn the response into an `EditPrediction` via `process_completion_response`.
    ///
    /// `perform_predict_edits` performs the actual network call. It is a parameter so that
    /// `fake_completion` can substitute a canned response.
    ///
    /// If that call fails with `ZedUpdateRequiredError`, `update_required` is set and an app
    /// notification linking to the Zed releases page is shown before the error is returned.
    fn request_completion_impl<F, R>(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
        let zeta = cx.entity();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Copy the history so the spawned task works on a stable list of events.
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let mut events = Vec::with_capacity(zeta_project.events.len());
        events.extend(zeta_project.events.iter().cloned());
        let events = Arc::new(events);

        // Git info is only looked up (and later attached) when the file itself may be collected.
        let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
            let can_collect_file = self.can_collect_file(file, cx);
            let git_info = if can_collect_file {
                git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
            } else {
                None
            };
            (git_info, can_collect_file)
        } else {
            (None, false)
        };

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().into_owned();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Renders the copied events within the `MAX_EVENT_TOKENS` budget.
        let prompt_for_events = {
            let events = events.clone();
            move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
        };
        let gather_task = gather_context(
            full_path_str,
            &snapshot,
            cursor_point,
            prompt_for_events,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                mut body,
                editable_range,
                included_events_count,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            // The newest `included_events_count` events; presumably exactly the ones that fit
            // into the prompt — TODO confirm against `gather_context`.
            let included_events = &events[events.len() - included_events_count..events.len()];
            // Collection requires both the file and `can_collect_events`; a dropped `Zeta`
            // counts as "no".
            body.can_collect_data = can_collect_file
                && this
                    .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                    .unwrap_or(false);
            if body.can_collect_data {
                body.git_info = git_info;
            }

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        // Remember that an update is required and tell the user, then still
                        // propagate the error below.
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Best effort: usage is dropped silently if `Zeta` no longer exists.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // Record latency for ~1% of requests: 3 of the 256 possible `u8` values (~1.2%).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 548
 549    #[cfg(any(test, feature = "test-support"))]
 550    pub fn fake_completion(
 551        &mut self,
 552        project: &Entity<Project>,
 553        buffer: &Entity<Buffer>,
 554        position: language::Anchor,
 555        response: PredictEditsResponse,
 556        cx: &mut Context<Self>,
 557    ) -> Task<Result<Option<EditPrediction>>> {
 558        self.request_completion_impl(project, buffer, position, cx, |_params| {
 559            std::future::ready(Ok((response, None)))
 560        })
 561    }
 562
 563    pub fn request_completion(
 564        &mut self,
 565        project: &Entity<Project>,
 566        buffer: &Entity<Buffer>,
 567        position: language::Anchor,
 568        cx: &mut Context<Self>,
 569    ) -> Task<Result<Option<EditPrediction>>> {
 570        self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits)
 571    }
 572
 573    pub fn perform_predict_edits(
 574        params: PerformPredictEditsParams,
 575    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 576        async move {
 577            let PerformPredictEditsParams {
 578                client,
 579                llm_token,
 580                app_version,
 581                body,
 582                ..
 583            } = params;
 584
 585            let http_client = client.http_client();
 586            let mut token = llm_token.acquire(&client).await?;
 587            let mut did_retry = false;
 588
 589            loop {
 590                let request_builder = http_client::Request::builder().method(Method::POST);
 591                let request_builder =
 592                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 593                        request_builder.uri(predict_edits_url)
 594                    } else {
 595                        request_builder.uri(
 596                            http_client
 597                                .build_zed_llm_url("/predict_edits/v2", &[])?
 598                                .as_ref(),
 599                        )
 600                    };
 601                let request = request_builder
 602                    .header("Content-Type", "application/json")
 603                    .header("Authorization", format!("Bearer {}", token))
 604                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 605                    .body(serde_json::to_string(&body)?.into())?;
 606
 607                let mut response = http_client.send(request).await?;
 608
 609                if let Some(minimum_required_version) = response
 610                    .headers()
 611                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 612                    .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
 613                {
 614                    anyhow::ensure!(
 615                        app_version >= minimum_required_version,
 616                        ZedUpdateRequiredError {
 617                            minimum_version: minimum_required_version
 618                        }
 619                    );
 620                }
 621
 622                if response.status().is_success() {
 623                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 624
 625                    let mut body = String::new();
 626                    response.body_mut().read_to_string(&mut body).await?;
 627                    return Ok((serde_json::from_str(&body)?, usage));
 628                } else if !did_retry
 629                    && response
 630                        .headers()
 631                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 632                        .is_some()
 633                {
 634                    did_retry = true;
 635                    token = llm_token.refresh(&client).await?;
 636                } else {
 637                    let mut body = String::new();
 638                    response.body_mut().read_to_string(&mut body).await?;
 639                    anyhow::bail!(
 640                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 641                        response.status(),
 642                        body
 643                    );
 644                }
 645            }
 646        }
 647    }
 648
 649    fn accept_edit_prediction(
 650        &mut self,
 651        request_id: EditPredictionId,
 652        cx: &mut Context<Self>,
 653    ) -> Task<Result<()>> {
 654        let client = self.client.clone();
 655        let llm_token = self.llm_token.clone();
 656        let app_version = AppVersion::global(cx);
 657        cx.spawn(async move |this, cx| {
 658            let http_client = client.http_client();
 659            let mut response = llm_token_retry(&llm_token, &client, |token| {
 660                let request_builder = http_client::Request::builder().method(Method::POST);
 661                let request_builder =
 662                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 663                        request_builder.uri(accept_prediction_url)
 664                    } else {
 665                        request_builder.uri(
 666                            http_client
 667                                .build_zed_llm_url("/predict_edits/accept", &[])?
 668                                .as_ref(),
 669                        )
 670                    };
 671                Ok(request_builder
 672                    .header("Content-Type", "application/json")
 673                    .header("Authorization", format!("Bearer {}", token))
 674                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 675                    .body(
 676                        serde_json::to_string(&AcceptEditPredictionBody {
 677                            request_id: request_id.0.to_string(),
 678                        })?
 679                        .into(),
 680                    )?)
 681            })
 682            .await?;
 683
 684            if let Some(minimum_required_version) = response
 685                .headers()
 686                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 687                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
 688                && app_version < minimum_required_version
 689            {
 690                return Err(anyhow!(ZedUpdateRequiredError {
 691                    minimum_version: minimum_required_version
 692                }));
 693            }
 694
 695            if response.status().is_success() {
 696                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 697                    this.update(cx, |this, cx| {
 698                        this.user_store.update(cx, |user_store, cx| {
 699                            user_store.update_edit_prediction_usage(usage, cx);
 700                        });
 701                    })?;
 702                }
 703
 704                Ok(())
 705            } else {
 706                let mut body = String::new();
 707                response.body_mut().read_to_string(&mut body).await?;
 708                Err(anyhow!(
 709                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 710                    response.status(),
 711                    body
 712                ))
 713            }
 714        })
 715    }
 716
 717    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 718        let client = self.client.clone();
 719        let llm_token = self.llm_token.clone();
 720        let app_version = AppVersion::global(cx);
 721        let last_rejection = self.discarded_completions.last().cloned();
 722        let body = serde_json::to_string(&RejectEditPredictionsBody {
 723            rejections: self.discarded_completions.clone(),
 724        })
 725        .ok();
 726
 727        let Some(last_rejection) = last_rejection else {
 728            return Task::ready(anyhow::Ok(()));
 729        };
 730
 731        cx.spawn(async move |this, cx| {
 732            let http_client = client.http_client();
 733            let mut response = llm_token_retry(&llm_token, &client, |token| {
 734                let request_builder = http_client::Request::builder().method(Method::POST);
 735                let request_builder = request_builder.uri(
 736                    http_client
 737                        .build_zed_llm_url("/predict_edits/reject", &[])?
 738                        .as_ref(),
 739                );
 740                Ok(request_builder
 741                    .header("Content-Type", "application/json")
 742                    .header("Authorization", format!("Bearer {}", token))
 743                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 744                    .body(
 745                        body.as_ref()
 746                            .context("failed to serialize body")?
 747                            .clone()
 748                            .into(),
 749                    )?)
 750            })
 751            .await?;
 752
 753            if let Some(minimum_required_version) = response
 754                .headers()
 755                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 756                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
 757                && app_version < minimum_required_version
 758            {
 759                return Err(anyhow!(ZedUpdateRequiredError {
 760                    minimum_version: minimum_required_version
 761                }));
 762            }
 763
 764            if response.status().is_success() {
 765                this.update(cx, |this, _| {
 766                    if let Some(ix) = this
 767                        .discarded_completions
 768                        .iter()
 769                        .position(|rejection| rejection.request_id == last_rejection.request_id)
 770                    {
 771                        this.discarded_completions.drain(..ix + 1);
 772                    }
 773                })
 774            } else {
 775                let mut body = String::new();
 776                response.body_mut().read_to_string(&mut body).await?;
 777                Err(anyhow!(
 778                    "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
 779                    response.status(),
 780                    body
 781                ))
 782            }
 783        })
 784    }
 785
    /// Converts a raw [`PredictEditsResponse`] into an [`EditPrediction`] for `buffer`.
    ///
    /// `snapshot` is the buffer state the request was built from: the model's
    /// `output_excerpt` is parsed into anchored edits against it on a background thread. Those
    /// edits are then rebased onto `buffer`'s *current* snapshot with
    /// `edit_prediction::interpolate_edits`; when that yields `None` the task resolves to
    /// `Ok(None)` (no prediction). Otherwise an edit preview is computed and awaited.
    ///
    /// The remaining inputs (`path`, `editable_range`, `cursor_offset`, the `input_*` strings
    /// and `buffer_snapshotted_at`) are stored on the returned prediction as-is, and
    /// `response_received_at` is stamped with `Instant::now()` once processing completes.
    ///
    /// # Errors
    ///
    /// The task fails if the excerpt cannot be parsed (see `parse_edits`), if reading `buffer`
    /// through the async context fails, or if the response's `request_id` is not a valid UUID.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing and diffing can be expensive, so do it off the main thread against the
            // snapshot the request was made from.
            let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // The buffer may have changed while the request was in flight: rebase the edits onto
            // its latest snapshot, and give up (Ok(None)) if interpolation fails.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                move |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
                            .into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 850
 851    fn parse_edits(
 852        output_excerpt: Arc<str>,
 853        editable_range: Range<usize>,
 854        snapshot: &BufferSnapshot,
 855    ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
 856        let content = output_excerpt.replace(CURSOR_MARKER, "");
 857
 858        let start_markers = content
 859            .match_indices(EDITABLE_REGION_START_MARKER)
 860            .collect::<Vec<_>>();
 861        anyhow::ensure!(
 862            start_markers.len() == 1,
 863            "expected exactly one start marker, found {}",
 864            start_markers.len()
 865        );
 866
 867        let end_markers = content
 868            .match_indices(EDITABLE_REGION_END_MARKER)
 869            .collect::<Vec<_>>();
 870        anyhow::ensure!(
 871            end_markers.len() == 1,
 872            "expected exactly one end marker, found {}",
 873            end_markers.len()
 874        );
 875
 876        let sof_markers = content
 877            .match_indices(START_OF_FILE_MARKER)
 878            .collect::<Vec<_>>();
 879        anyhow::ensure!(
 880            sof_markers.len() <= 1,
 881            "expected at most one start-of-file marker, found {}",
 882            sof_markers.len()
 883        );
 884
 885        let codefence_start = start_markers[0].0;
 886        let content = &content[codefence_start..];
 887
 888        let newline_ix = content.find('\n').context("could not find newline")?;
 889        let content = &content[newline_ix + 1..];
 890
 891        let codefence_end = content
 892            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 893            .context("could not find end marker")?;
 894        let new_text = &content[..codefence_end];
 895
 896        let old_text = snapshot
 897            .text_for_range(editable_range.clone())
 898            .collect::<String>();
 899
 900        Ok(Self::compute_edits(
 901            old_text,
 902            new_text,
 903            editable_range.start,
 904            snapshot,
 905        ))
 906    }
 907
 908    pub fn compute_edits(
 909        old_text: String,
 910        new_text: &str,
 911        offset: usize,
 912        snapshot: &BufferSnapshot,
 913    ) -> Vec<(Range<Anchor>, Arc<str>)> {
 914        text_diff(&old_text, new_text)
 915            .into_iter()
 916            .map(|(mut old_range, new_text)| {
 917                old_range.start += offset;
 918                old_range.end += offset;
 919
 920                let prefix_len = common_prefix(
 921                    snapshot.chars_for_range(old_range.clone()),
 922                    new_text.chars(),
 923                );
 924                old_range.start += prefix_len;
 925
 926                let suffix_len = common_prefix(
 927                    snapshot.reversed_chars_for_range(old_range.clone()),
 928                    new_text[prefix_len..].chars().rev(),
 929                );
 930                old_range.end = old_range.end.saturating_sub(suffix_len);
 931
 932                let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
 933                let range = if old_range.is_empty() {
 934                    let anchor = snapshot.anchor_after(old_range.start);
 935                    anchor..anchor
 936                } else {
 937                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 938                };
 939                (range, new_text)
 940            })
 941            .collect()
 942    }
 943
    /// Returns whether `completion_id` has been rated (see `rate_completion`) and is still
    /// tracked in the rated set.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
 947
 948    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
 949        self.shown_completions.push_front(completion.clone());
 950        if self.shown_completions.len() > 50 {
 951            let completion = self.shown_completions.pop_back().unwrap();
 952            self.rated_completions.remove(&completion.id);
 953        }
 954        cx.notify();
 955    }
 956
    /// Records the user's `rating` and free-form `feedback` for `completion`.
    ///
    /// Marks the completion as rated and emits an "Edit Prediction Rated" telemetry event
    /// carrying the prediction's inputs and output. It then requests a telemetry flush (the
    /// flush task is detached) and notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
 977
    /// Iterates over recently shown predictions, most recent first (see `completion_shown`).
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
 981
    /// Number of predictions currently kept in the shown-completions history.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
 985
    /// Records a `BufferChange` event for `buffer` if it changed since it was last seen, and
    /// returns the buffer's current snapshot.
    ///
    /// The buffer is first registered with the per-project state (`get_or_init_zeta_project` /
    /// `register_buffer_impl`). The snapshot stored there is compared by version with the
    /// buffer's current snapshot. On a mismatch the stored snapshot is replaced, and an event
    /// holding both the old and new snapshots is pushed.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            // Swap in the new snapshot and keep the previous one for the event.
            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
1010
1011    fn can_collect_file(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1012        self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx)
1013    }
1014
1015    fn can_collect_events(&self, events: &[Event], cx: &App) -> bool {
1016        if !self.data_collection_choice.is_enabled() {
1017            return false;
1018        }
1019        let mut last_checked_file = None;
1020        for event in events {
1021            match event {
1022                Event::BufferChange {
1023                    old_snapshot,
1024                    new_snapshot,
1025                    ..
1026                } => {
1027                    if let Some(old_file) = old_snapshot.file()
1028                        && let Some(new_file) = new_snapshot.file()
1029                    {
1030                        if let Some(last_checked_file) = last_checked_file
1031                            && Arc::ptr_eq(last_checked_file, old_file)
1032                            && Arc::ptr_eq(last_checked_file, new_file)
1033                        {
1034                            continue;
1035                        }
1036                        if !self.can_collect_file(old_file, cx) {
1037                            return false;
1038                        }
1039                        if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx)
1040                        {
1041                            return false;
1042                        }
1043                        last_checked_file = Some(new_file);
1044                    } else {
1045                        return false;
1046                    }
1047                }
1048            }
1049        }
1050        true
1051    }
1052
1053    fn is_file_open_source(&self, file: &Arc<dyn File>, cx: &App) -> bool {
1054        if !file.is_local() || file.is_private() {
1055            return false;
1056        }
1057        self.license_detection_watchers
1058            .get(&file.worktree_id(cx))
1059            .is_some_and(|watcher| watcher.is_project_open_source())
1060    }
1061
1062    fn load_data_collection_choice() -> DataCollectionChoice {
1063        let choice = KEY_VALUE_STORE
1064            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1065            .log_err()
1066            .flatten();
1067
1068        match choice.as_deref() {
1069            Some("true") => DataCollectionChoice::Enabled,
1070            Some("false") => DataCollectionChoice::Disabled,
1071            Some(_) => {
1072                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1073                DataCollectionChoice::NotAnswered
1074            }
1075            None => DataCollectionChoice::NotAnswered,
1076        }
1077    }
1078
1079    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1080        self.data_collection_choice = self.data_collection_choice.toggle();
1081        let new_choice = self.data_collection_choice;
1082        db::write_and_log(cx, move || {
1083            KEY_VALUE_STORE.write_kvp(
1084                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1085                new_choice.is_enabled().to_string(),
1086            )
1087        });
1088    }
1089
1090    fn discard_completion(
1091        &mut self,
1092        completion_id: EditPredictionId,
1093        was_shown: bool,
1094        cx: &mut Context<Self>,
1095    ) {
1096        self.discarded_completions.push(EditPredictionRejection {
1097            request_id: completion_id.to_string(),
1098            was_shown,
1099        });
1100
1101        let reached_request_limit =
1102            self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
1103        let discard_completions_tx = self.discard_completions_tx.clone();
1104        self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
1105            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
1106            if !reached_request_limit {
1107                cx.background_executor()
1108                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
1109                    .await;
1110            }
1111            discard_completions_tx.unbounded_send(()).log_err();
1112        }));
1113    }
1114}
1115
/// Inputs for sending a predict-edits request.
///
/// NOTE(review): the consumer of this struct is not visible in this chunk; presumably it is the
/// function that performs the authenticated HTTP request — confirm.
pub struct PerformPredictEditsParams {
    /// API client (presumably used for the HTTP request and token acquisition).
    pub client: Arc<Client>,
    /// LLM API token to authenticate the request with.
    pub llm_token: LlmApiToken,
    /// Version of the running app.
    pub app_version: Version,
    /// Request payload.
    pub body: PredictEditsBody,
}
1122
/// Returned when the server's `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header names a
/// version newer than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum app version reported by the server.
    minimum_version: Version,
}
1130
/// Length in UTF-8 bytes of the longest common prefix of two character sequences.
///
/// Characters are compared pairwise until the first mismatch or until either side runs out.
/// The byte lengths of the matching characters are summed.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1137
1138fn git_info_for_file(
1139    project: &Entity<Project>,
1140    project_path: &ProjectPath,
1141    cx: &App,
1142) -> Option<PredictEditsGitInfo> {
1143    let git_store = project.read(cx).git_store().read(cx);
1144    if let Some((repository, _repo_path)) =
1145        git_store.repository_and_path_for_project_path(project_path, cx)
1146    {
1147        let repository = repository.read(cx);
1148        let head_sha = repository
1149            .head_commit
1150            .as_ref()
1151            .map(|head_commit| head_commit.sha.to_string());
1152        let remote_origin_url = repository.remote_origin_url.clone();
1153        let remote_upstream_url = repository.remote_upstream_url.clone();
1154        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1155            return None;
1156        }
1157        Some(PredictEditsGitInfo {
1158            head_sha,
1159            remote_origin_url,
1160            remote_upstream_url,
1161        })
1162    } else {
1163        None
1164    }
1165}
1166
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// Request payload containing the rendered events and input excerpt.
    pub body: PredictEditsBody,
    /// Byte range of the buffer that the model is allowed to rewrite.
    pub editable_range: Range<usize>,
    /// Number of (most recent) events included in `body.input_events`.
    pub included_events_count: usize,
}
1172
1173pub fn gather_context(
1174    full_path_str: String,
1175    snapshot: &BufferSnapshot,
1176    cursor_point: language::Point,
1177    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
1178    cx: &App,
1179) -> Task<Result<GatherContextOutput>> {
1180    cx.background_spawn({
1181        let snapshot = snapshot.clone();
1182        async move {
1183            let input_excerpt = excerpt_for_cursor_position(
1184                cursor_point,
1185                &full_path_str,
1186                &snapshot,
1187                MAX_REWRITE_TOKENS,
1188                MAX_CONTEXT_TOKENS,
1189            );
1190            let (input_events, included_events_count) = prompt_for_events();
1191            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1192
1193            let body = PredictEditsBody {
1194                input_events,
1195                input_excerpt: input_excerpt.prompt,
1196                can_collect_data: false,
1197                diagnostic_groups: None,
1198                git_info: None,
1199                outline: None,
1200                speculated_output: None,
1201            };
1202
1203            Ok(GatherContextOutput {
1204                body,
1205                editable_range,
1206                included_events_count,
1207            })
1208        }
1209    })
1210}
1211
1212fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) {
1213    let mut result = String::new();
1214    for (ix, event) in events.iter().rev().enumerate() {
1215        let event_string = event.to_prompt();
1216        let event_tokens = guess_token_count(event_string.len());
1217        if event_tokens > remaining_tokens {
1218            return (result, ix);
1219        }
1220
1221        if !result.is_empty() {
1222            result.insert_str(0, "\n\n");
1223        }
1224        result.insert_str(0, &event_string);
1225        remaining_tokens -= event_tokens;
1226    }
1227    return (result, events.len());
1228}
1229
/// Per-buffer state kept for a buffer registered with Zeta.
struct RegisteredBuffer {
    /// Last snapshot seen; compared by version against the current one to detect changes.
    snapshot: BufferSnapshot,
    /// Held only to keep the subscriptions alive; what they observe is set up outside this chunk.
    _subscriptions: [gpui::Subscription; 2],
}
1234
/// A user action recorded for use as prediction context (rendered with `to_prompt`).
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents, and possibly its file path, changed.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1243
1244impl Event {
1245    fn to_prompt(&self) -> String {
1246        match self {
1247            Event::BufferChange {
1248                old_snapshot,
1249                new_snapshot,
1250                ..
1251            } => {
1252                let mut prompt = String::new();
1253
1254                let old_path = old_snapshot
1255                    .file()
1256                    .map(|f| f.path().as_ref())
1257                    .unwrap_or(RelPath::unix("untitled").unwrap());
1258                let new_path = new_snapshot
1259                    .file()
1260                    .map(|f| f.path().as_ref())
1261                    .unwrap_or(RelPath::unix("untitled").unwrap());
1262                if old_path != new_path {
1263                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1264                }
1265
1266                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1267                if !diff.is_empty() {
1268                    write!(
1269                        prompt,
1270                        "User edited {:?}:\n```diff\n{}\n```",
1271                        new_path, diff
1272                    )
1273                    .unwrap();
1274                }
1275
1276                prompt
1277            }
1278        }
1279    }
1280}
1281
/// The prediction currently offered by `ZetaEditPredictionProvider`, plus bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
    /// Set once the editor reports the prediction as displayed (`did_show`).
    was_shown: bool,
    /// Set when the prediction is accepted; accepted predictions are not reported as discarded.
    was_accepted: bool,
}
1289
1290impl CurrentEditPrediction {
1291    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1292        if self.buffer_id != old_completion.buffer_id {
1293            return true;
1294        }
1295
1296        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1297            return true;
1298        };
1299        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1300            return false;
1301        };
1302
1303        if old_edits.len() == 1 && new_edits.len() == 1 {
1304            let (old_range, old_text) = &old_edits[0];
1305            let (new_range, new_text) = &new_edits[0];
1306            new_range == old_range && new_text.starts_with(old_text.as_ref())
1307        } else {
1308            true
1309        }
1310    }
1311}
1312
/// An in-flight prediction request started by `ZetaEditPredictionProvider::refresh`.
struct PendingCompletion {
    /// Sequential request id (`next_pending_completion_id` at the time it was started).
    id: usize,
    /// Task driving the request; dropping it abandons the request.
    task: Task<()>,
}
1317
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been made yet; treated as disabled.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::NotAnswered | Self::Disabled => false,
            Self::Enabled => true,
        }
    }

    /// `true` once the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// The opposite choice: `Enabled` becomes `Disabled`; `Disabled` and `NotAnswered` both
    /// become `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(enabled: bool) -> Self {
        if enabled {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1358
1359async fn llm_token_retry(
1360    llm_token: &LlmApiToken,
1361    client: &Arc<Client>,
1362    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1363) -> Result<Response<AsyncBody>> {
1364    let mut did_retry = false;
1365    let http_client = client.http_client();
1366    let mut token = llm_token.acquire(client).await?;
1367    loop {
1368        let request = build_request(token.clone())?;
1369        let response = http_client.send(request).await?;
1370
1371        if !did_retry
1372            && !response.status().is_success()
1373            && response
1374                .headers()
1375                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1376                .is_some()
1377        {
1378            did_retry = true;
1379            token = llm_token.refresh(client).await?;
1380            continue;
1381        }
1382
1383        return Ok(response);
1384    }
1385}
1386
/// `edit_prediction::EditPredictionProvider` implementation backed by a `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Buffer used to report data-collection state (presumably set when the editor shows a
    /// single buffer — confirm at construction sites).
    singleton_buffer: Option<Entity<Buffer>>,
    /// In-flight requests, oldest first; at most two are tracked.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Superseded requests kept running, keyed by request id, so that a prediction they still
    /// produce can be reported as discarded.
    canceled_completions: HashMap<usize, Task<()>>,
    /// Id assigned to the next request started by `refresh`.
    next_pending_completion_id: usize,
    /// The prediction currently offered to the user, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was issued; used to throttle requests.
    last_request_timestamp: Instant,
    project: Entity<Project>,
}
1397
1398impl ZetaEditPredictionProvider {
1399    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1400
1401    pub fn new(
1402        zeta: Entity<Zeta>,
1403        project: Entity<Project>,
1404        singleton_buffer: Option<Entity<Buffer>>,
1405        cx: &mut Context<Self>,
1406    ) -> Self {
1407        cx.on_release(|this, cx| {
1408            this.take_current_edit_prediction(cx);
1409        })
1410        .detach();
1411
1412        Self {
1413            zeta,
1414            singleton_buffer,
1415            pending_completions: ArrayVec::new(),
1416            canceled_completions: HashMap::default(),
1417            next_pending_completion_id: 0,
1418            current_completion: None,
1419            last_request_timestamp: Instant::now(),
1420            project,
1421        }
1422    }
1423
1424    fn take_current_edit_prediction(&mut self, cx: &mut App) {
1425        if let Some(completion) = self.current_completion.take() {
1426            if !completion.was_accepted {
1427                self.zeta.update(cx, |zeta, cx| {
1428                    zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
1429                });
1430            }
1431        }
1432    }
1433}
1434
1435impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Internal identifier of this provider.
    fn name() -> &'static str {
        "zed-predict"
    }

    /// Human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }

    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
1451
1452    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1453        if let Some(buffer) = &self.singleton_buffer
1454            && let Some(file) = buffer.read(cx).file()
1455        {
1456            let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx);
1457            if self.zeta.read(cx).data_collection_choice.is_enabled() {
1458                DataCollectionState::Enabled {
1459                    is_project_open_source,
1460                }
1461            } else {
1462                DataCollectionState::Disabled {
1463                    is_project_open_source,
1464                }
1465            }
1466        } else {
1467            return DataCollectionState::Disabled {
1468                is_project_open_source: false,
1469            };
1470        }
1471    }
1472
1473    fn toggle_data_collection(&mut self, cx: &mut App) {
1474        self.zeta
1475            .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx));
1476    }
1477
    /// Edit-prediction usage, as reported by `Zeta::usage`.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
1481
    /// Always enabled, regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// `true` while at least one request is pending (canceled requests are not counted).
    fn is_refreshing(&self, _cx: &App) -> bool {
        !self.pending_completions.is_empty()
    }
1493
1494    fn refresh(
1495        &mut self,
1496        buffer: Entity<Buffer>,
1497        position: language::Anchor,
1498        _debounce: bool,
1499        cx: &mut Context<Self>,
1500    ) {
1501        if self.zeta.read(cx).update_required {
1502            return;
1503        }
1504
1505        if self
1506            .zeta
1507            .read(cx)
1508            .user_store
1509            .read_with(cx, |user_store, _cx| {
1510                user_store.account_too_young() || user_store.has_overdue_invoices()
1511            })
1512        {
1513            return;
1514        }
1515
1516        if let Some(current_completion) = self.current_completion.as_ref() {
1517            let snapshot = buffer.read(cx).snapshot();
1518            if current_completion
1519                .completion
1520                .interpolate(&snapshot)
1521                .is_some()
1522            {
1523                return;
1524            }
1525        }
1526
1527        let pending_completion_id = self.next_pending_completion_id;
1528        self.next_pending_completion_id += 1;
1529        let last_request_timestamp = self.last_request_timestamp;
1530
1531        let project = self.project.clone();
1532        let task = cx.spawn(async move |this, cx| {
1533            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1534                .checked_duration_since(Instant::now())
1535            {
1536                cx.background_executor().timer(timeout).await;
1537            }
1538
1539            let completion_request = this.update(cx, |this, cx| {
1540                this.last_request_timestamp = Instant::now();
1541                this.zeta.update(cx, |zeta, cx| {
1542                    zeta.request_completion(&project, &buffer, position, cx)
1543                })
1544            });
1545
1546            let completion = match completion_request {
1547                Ok(completion_request) => {
1548                    let completion_request = completion_request.await;
1549                    completion_request.map(|c| {
1550                        c.map(|completion| CurrentEditPrediction {
1551                            buffer_id: buffer.entity_id(),
1552                            completion,
1553                            was_shown: false,
1554                            was_accepted: false,
1555                        })
1556                    })
1557                }
1558                Err(error) => Err(error),
1559            };
1560
1561            let discarded = this
1562                .update(cx, |this, cx| {
1563                    if this
1564                        .pending_completions
1565                        .first()
1566                        .is_some_and(|completion| completion.id == pending_completion_id)
1567                    {
1568                        this.pending_completions.remove(0);
1569                    } else {
1570                        if let Some(discarded) = this.pending_completions.drain(..).next() {
1571                            this.canceled_completions
1572                                .insert(discarded.id, discarded.task);
1573                        }
1574                    }
1575
1576                    let canceled = this.canceled_completions.remove(&pending_completion_id);
1577
1578                    if canceled.is_some()
1579                        && let Ok(Some(new_completion)) = &completion
1580                    {
1581                        this.zeta.update(cx, |zeta, cx| {
1582                            zeta.discard_completion(new_completion.completion.id, false, cx);
1583                        });
1584                        return true;
1585                    }
1586
1587                    cx.notify();
1588                    false
1589                })
1590                .ok()
1591                .unwrap_or(true);
1592
1593            if discarded {
1594                return;
1595            }
1596
1597            let Some(new_completion) = completion
1598                .context("edit prediction failed")
1599                .log_err()
1600                .flatten()
1601            else {
1602                return;
1603            };
1604
1605            this.update(cx, |this, cx| {
1606                if let Some(old_completion) = this.current_completion.as_ref() {
1607                    let snapshot = buffer.read(cx).snapshot();
1608                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1609                        this.zeta.update(cx, |zeta, cx| {
1610                            zeta.completion_shown(&new_completion.completion, cx);
1611                        });
1612                        this.take_current_edit_prediction(cx);
1613                        this.current_completion = Some(new_completion);
1614                    }
1615                } else {
1616                    this.zeta.update(cx, |zeta, cx| {
1617                        zeta.completion_shown(&new_completion.completion, cx);
1618                    });
1619                    this.current_completion = Some(new_completion);
1620                }
1621
1622                cx.notify();
1623            })
1624            .ok();
1625        });
1626
1627        // We always maintain at most two pending completions. When we already
1628        // have two, we replace the newest one.
1629        if self.pending_completions.len() <= 1 {
1630            self.pending_completions.push(PendingCompletion {
1631                id: pending_completion_id,
1632                task,
1633            });
1634        } else if self.pending_completions.len() == 2 {
1635            if let Some(discarded) = self.pending_completions.pop() {
1636                self.canceled_completions
1637                    .insert(discarded.id, discarded.task);
1638            }
1639            self.pending_completions.push(PendingCompletion {
1640                id: pending_completion_id,
1641                task,
1642            });
1643        }
1644    }
1645
    /// No-op: this provider offers a single prediction at a time.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }
1655
1656    fn accept(&mut self, cx: &mut Context<Self>) {
1657        let completion = self.current_completion.as_mut();
1658        if let Some(completion) = completion {
1659            completion.was_accepted = true;
1660            self.zeta
1661                .update(cx, |zeta, cx| {
1662                    zeta.accept_edit_prediction(completion.completion.id, cx)
1663                })
1664                .detach();
1665        }
1666        self.pending_completions.clear();
1667    }
1668
1669    fn discard(&mut self, cx: &mut Context<Self>) {
1670        self.pending_completions.clear();
1671        self.take_current_edit_prediction(cx);
1672    }
1673
1674    fn did_show(&mut self, _cx: &mut Context<Self>) {
1675        if let Some(current_completion) = self.current_completion.as_mut() {
1676            current_completion.was_shown = true;
1677        }
1678    }
1679
1680    fn suggest(
1681        &mut self,
1682        buffer: &Entity<Buffer>,
1683        cursor_position: language::Anchor,
1684        cx: &mut Context<Self>,
1685    ) -> Option<edit_prediction::EditPrediction> {
1686        let CurrentEditPrediction {
1687            buffer_id,
1688            completion,
1689            ..
1690        } = self.current_completion.as_mut()?;
1691
1692        // Invalidate previous completion if it was generated for a different buffer.
1693        if *buffer_id != buffer.entity_id() {
1694            self.take_current_edit_prediction(cx);
1695            return None;
1696        }
1697
1698        let buffer = buffer.read(cx);
1699        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1700            self.take_current_edit_prediction(cx);
1701            return None;
1702        };
1703
1704        let cursor_row = cursor_position.to_point(buffer).row;
1705        let (closest_edit_ix, (closest_edit_range, _)) =
1706            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1707                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1708                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1709                cmp::min(distance_from_start, distance_from_end)
1710            })?;
1711
1712        let mut edit_start_ix = closest_edit_ix;
1713        for (range, _) in edits[..edit_start_ix].iter().rev() {
1714            let distance_from_closest_edit =
1715                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1716            if distance_from_closest_edit <= 1 {
1717                edit_start_ix -= 1;
1718            } else {
1719                break;
1720            }
1721        }
1722
1723        let mut edit_end_ix = closest_edit_ix + 1;
1724        for (range, _) in &edits[edit_end_ix..] {
1725            let distance_from_closest_edit =
1726                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1727            if distance_from_closest_edit <= 1 {
1728                edit_end_ix += 1;
1729            } else {
1730                break;
1731            }
1732        }
1733
1734        Some(edit_prediction::EditPrediction::Local {
1735            id: Some(completion.id.to_string().into()),
1736            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1737            edit_preview: Some(completion.edit_preview.clone()),
1738        })
1739    }
1740}
1741
/// Rough number of text bytes assumed per model token when limiting model input.
///
/// The value is deliberately small. Dividing by it therefore yields a higher
/// token estimate, which errs on the side of underestimating how much input
/// fits within a limit.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1749
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::{FakeSystemClock, ReplicaId};
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use parking_lot::Mutex;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    use super::*;

    const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test(cx);

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
        );

        assert_eq!(
            apply_edit_prediction(
                indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
                indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
                cx,
            )
            .await,
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
        );
    }

    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        assert_eq!(
            apply_edit_prediction(buffer_content, completion_response, cx).await,
            "lorem\nipsum"
        );
    }

    #[gpui::test]
    async fn test_can_collect_data(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/project/src/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));

        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Disabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let buffer = cx.new(|_cx| {
            Buffer::remote(
                language::BufferId::new(1).unwrap(),
                ReplicaId::new(1),
                language::Capability::ReadWrite,
                "fn main() {\n    println!(\"Hello\");\n}",
            )
        });

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/project"),
            json!({
                "LICENSE": BSD_0_TXT,
                ".env": "SECRET_KEY=secret"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                // `path!` keeps the path valid on Windows, matching the worktree root above.
                project.open_local_buffer(path!("/project/.env"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;
        let buffer = cx.new(|cx| Buffer::local("", cx));

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
            .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                // `path!` keeps the path valid on Windows, matching the worktree root above.
                project.open_local_buffer(path!("/project/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/open_source_worktree"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [
                path!("/open_source_worktree").as_ref(),
                path!("/closed_source_worktree").as_ref(),
            ],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));

        let closed_source_file = project
            .update(cx, |project, cx| {
                let worktree2 = project
                    .worktree_for_root_name("closed_source_worktree", cx)
                    .unwrap();
                worktree2.update(cx, |worktree2, cx| {
                    worktree2.load_file(rel_path("main.rs"), cx)
                })
            })
            .await
            .unwrap()
            .file;

        buffer.update(cx, |buffer, cx| {
            buffer.file_updated(closed_source_file, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));
    }

    #[gpui::test]
    async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/worktree1"),
            json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
        )
        .await;
        fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
            .await;

        let project = Project::test(
            fs.clone(),
            [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
            cx,
        )
        .await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/worktree1/main.rs"), cx)
            })
            .await
            .unwrap();
        let private_buffer = project
            .update(cx, |project, cx| {
                // Open the file the fixture actually created in the unlicensed worktree.
                project.open_local_buffer(path!("/worktree2/private.rs"), cx)
            })
            .await
            .unwrap();

        let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
        zeta.update(cx, |zeta, _cx| {
            zeta.data_collection_choice = DataCollectionChoice::Enabled
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));

        // this has a side effect of registering the buffer to watch for edits
        run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));

        private_buffer.update(cx, |private_buffer, cx| {
            private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(!last_request_can_collect_data(&captured_request));

        // Make an edit large enough to use up the event token budget, so the earlier
        // private_buffer edit can no longer be included in the request.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))],
                None,
                cx,
            );
        });

        run_edit_prediction(&buffer, &project, &zeta, cx).await;
        assert!(last_request_can_collect_data(&captured_request));
    }

    /// Installs a test `SettingsStore` as a global so the code under test can read settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Returns the `can_collect_data` flag of the most recently captured
    /// `/predict_edits/v2` request body.
    ///
    /// Panics if no request has been captured yet.
    fn last_request_can_collect_data(captured_request: &Mutex<Option<PredictEditsBody>>) -> bool {
        captured_request
            .lock()
            .as_ref()
            .expect("no /predict_edits/v2 request was captured")
            .can_collect_data
    }

    /// Creates a local buffer with `buffer_content`. It then runs a prediction whose
    /// server response is `completion_response`, applies the predicted edits, and
    /// returns the resulting buffer text.
    async fn apply_edit_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> String {
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let (zeta, _, response) = make_test_zeta(&project, cx).await;
        *response.lock() = completion_response.to_string();
        let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
        buffer.update(cx, |buffer, cx| {
            buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
        });
        buffer.read_with(cx, |buffer, _| buffer.text())
    }

    /// Registers `buffer` with `zeta` and waits for background work to settle.
    /// It then requests a prediction with the cursor anchored before row 1,
    /// column 0.
    ///
    /// Panics if the request fails or produces no prediction.
    async fn run_edit_prediction(
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        zeta: &Entity<Zeta>,
        cx: &mut TestAppContext,
    ) -> EditPrediction {
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, project, cx));
        cx.background_executor.run_until_parked();
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(project, buffer, cursor, cx)
        });
        completion_task.await.unwrap().unwrap()
    }

    /// Builds a `Zeta` backed by a fake HTTP client.
    ///
    /// Returns the entity, plus shared slots holding the last captured
    /// `/predict_edits/v2` request body and the `output_excerpt` the fake
    /// server will respond with. The response can be overwritten by the caller.
    /// License detection watchers are eagerly created for each of the
    /// project's worktrees.
    async fn make_test_zeta(
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        Arc<Mutex<Option<PredictEditsBody>>>,
        Arc<Mutex<String>>,
    ) {
        let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
        };
        let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
        let completion_response: Arc<Mutex<String>> =
            Arc::new(Mutex::new(default_response.to_string()));
        let http_client = FakeHttpClient::create({
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            move |req| {
                let captured_request = captured_request.clone();
                let completion_response = completion_response.clone();
                async move {
                    match (req.method(), req.uri().path()) {
                        (&Method::POST, "/client/llm_tokens") => {
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&CreateLlmTokenResponse {
                                        token: LlmToken("the-llm-token".to_string()),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        (&Method::POST, "/predict_edits/v2") => {
                            let mut request_body = String::new();
                            req.into_body().read_to_string(&mut request_body).await?;
                            *captured_request.lock() =
                                Some(serde_json::from_str(&request_body).unwrap());
                            Ok(http_client::Response::builder()
                                .status(200)
                                .body(
                                    serde_json::to_string(&PredictEditsResponse {
                                        request_id: Uuid::new_v4().to_string(),
                                        output_excerpt: completion_response.lock().clone(),
                                    })
                                    .unwrap()
                                    .into(),
                                )
                                .unwrap())
                        }
                        _ => Ok(http_client::Response::builder()
                            .status(404)
                            .body("Not Found".into())
                            .unwrap()),
                    }
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let _server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new(|cx| {
            let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);

            let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
            for worktree in worktrees {
                let worktree_id = worktree.read(cx).id();
                zeta.license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
            }

            zeta
        });

        (zeta, captured_request, completion_response)
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`.
    /// The start is anchored after and the end before its offset.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back into offset-based edits in `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}