zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::{Dismissable, KEY_VALUE_STORE};
  11use edit_prediction::DataCollectionState;
  12pub use init::*;
  13use license_detection::LicenseDetectionWatcher;
  14pub use rate_completion_modal::*;
  15
  16use anyhow::{Context as _, Result, anyhow};
  17use arrayvec::ArrayVec;
  18use client::{Client, EditPredictionUsage, UserStore};
  19use cloud_llm_client::{
  20    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  21    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
  22};
  23use collections::{HashMap, HashSet, VecDeque};
  24use futures::AsyncReadExt;
  25use gpui::{
  26    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  27    SharedString, Subscription, Task, actions,
  28};
  29use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  30use input_excerpt::excerpt_for_cursor_position;
  31use language::{
  32    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  33};
  34use language_model::{LlmApiToken, RefreshLlmTokenListener};
  35use project::{Project, ProjectPath};
  36use release_channel::AppVersion;
  37use settings::WorktreeId;
  38use std::collections::hash_map;
  39use std::mem;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    ops::Range,
  46    path::Path,
  47    rc::Rc,
  48    sync::Arc,
  49    time::{Duration, Instant},
  50};
  51use telemetry_events::EditPredictionRating;
  52use thiserror::Error;
  53use util::ResultExt;
  54use uuid::Uuid;
  55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  56use worktree::Worktree;
  57
/// Cursor-position marker; removed from the model output before the editable
/// region is parsed (see `Zeta::parse_edits`).
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker; the model output may contain it at most once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region in the model output; the predicted text starts on
/// the line after this marker. Exactly one occurrence is required.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region; must appear exactly once and be preceded by a
/// newline.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Consecutive changes to the same buffer that arrive within this interval are
/// merged into a single event (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value-store key holding the user's data-collection choice. Its presence
/// also counts as having dismissed the upsell (see `ZedPredictUpsell::dismissed`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// NOTE(review): MAX_CONTEXT_TOKENS and MAX_REWRITE_TOKENS are not referenced in
// this part of the file; presumably they size the context / editable parts of the
// excerpt sent to the model — confirm at their use sites.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget passed to `prompt_for_events` when rendering recent events.
const MAX_EVENT_TOKENS: usize = 500;
// NOTE(review): not referenced in this part of the file; presumably caps the
// number of diagnostic groups included in the request — confirm.
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  80
  81#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  82pub struct EditPredictionId(Uuid);
  83
  84impl From<EditPredictionId> for gpui::ElementId {
  85    fn from(value: EditPredictionId) -> Self {
  86        gpui::ElementId::Uuid(value.0)
  87    }
  88}
  89
  90impl std::fmt::Display for EditPredictionId {
  91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  92        write!(f, "{}", self.0)
  93    }
  94}
  95
  96struct ZedPredictUpsell;
  97
  98impl Dismissable for ZedPredictUpsell {
  99    const KEY: &'static str = "dismissed-edit-predict-upsell";
 100
 101    fn dismissed() -> bool {
 102        // To make this backwards compatible with older versions of Zed, we
 103        // check if the user has seen the previous Edit Prediction Onboarding
 104        // before, by checking the data collection choice which was written to
 105        // the database once the user clicked on "Accept and Enable"
 106        if KEY_VALUE_STORE
 107            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 108            .log_err()
 109            .is_some_and(|s| s.is_some())
 110        {
 111            return true;
 112        }
 113
 114        KEY_VALUE_STORE
 115            .read_kvp(Self::KEY)
 116            .log_err()
 117            .is_some_and(|s| s.is_some())
 118    }
 119}
 120
 121pub fn should_show_upsell_modal() -> bool {
 122    !ZedPredictUpsell::dismissed()
 123}
 124
/// App-global handle to the shared `Zeta` entity. It is installed by
/// `Zeta::register` and read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 129
/// A prediction returned by the model, together with the request inputs that
/// produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Request id reported by the server for this prediction.
    id: EditPredictionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has none.
    path: Arc<Path>,
    /// Editable range (offsets into the request-time snapshot) that the model
    /// rewrote.
    excerpt_range: Range<usize>,
    /// Cursor offset in the request-time snapshot.
    cursor_offset: usize,
    /// Predicted edits, interpolated onto `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` were last interpolated against.
    snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    edit_preview: EditPreview,
    /// Outline text sent with the request (empty when there was none).
    input_outline: Arc<str>,
    /// Recent-events text sent with the request.
    input_events: Arc<str>,
    /// Buffer excerpt sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model, including region markers.
    output_excerpt: Arc<str>,
    /// When the buffer was snapshotted for the request.
    buffer_snapshotted_at: Instant,
    /// When this prediction was built. It is stamped after parsing and edit
    /// preview, so it includes local processing time, not only the network
    /// round-trip.
    response_received_at: Instant,
}
 146
 147impl EditPrediction {
 148    fn latency(&self) -> Duration {
 149        self.response_received_at
 150            .duration_since(self.buffer_snapshotted_at)
 151    }
 152
 153    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 154        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 155    }
 156}
 157
 158fn interpolate(
 159    old_snapshot: &BufferSnapshot,
 160    new_snapshot: &BufferSnapshot,
 161    current_edits: Arc<[(Range<Anchor>, String)]>,
 162) -> Option<Vec<(Range<Anchor>, String)>> {
 163    let mut edits = Vec::new();
 164
 165    let mut model_edits = current_edits.iter().peekable();
 166    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 167        while let Some((model_old_range, _)) = model_edits.peek() {
 168            let model_old_range = model_old_range.to_offset(old_snapshot);
 169            if model_old_range.end < user_edit.old.start {
 170                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 171                edits.push((model_old_range.clone(), model_new_text.clone()));
 172            } else {
 173                break;
 174            }
 175        }
 176
 177        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 178            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 179            if user_edit.old == model_old_offset_range {
 180                let user_new_text = new_snapshot
 181                    .text_for_range(user_edit.new.clone())
 182                    .collect::<String>();
 183
 184                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 185                    if !model_suffix.is_empty() {
 186                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 187                        edits.push((anchor..anchor, model_suffix.to_string()));
 188                    }
 189
 190                    model_edits.next();
 191                    continue;
 192                }
 193            }
 194        }
 195
 196        return None;
 197    }
 198
 199    edits.extend(model_edits.cloned());
 200
 201    if edits.is_empty() { None } else { Some(edits) }
 202}
 203
 204impl std::fmt::Debug for EditPrediction {
 205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 206        f.debug_struct("EditPrediction")
 207            .field("id", &self.id)
 208            .field("path", &self.path)
 209            .field("edits", &self.edits)
 210            .finish_non_exhaustive()
 211    }
 212}
 213
/// Edit-prediction state shared across the app. A single instance is stored in
/// `ZetaGlobal` (see `Zeta::register`).
pub struct Zeta {
    /// Per-project state keyed by the project entity id. An entry is removed
    /// when its project is released.
    projects: HashMap<EntityId, ZetaProject>,
    /// Client used for HTTP requests and LLM token acquisition/refresh.
    client: Arc<Client>,
    // NOTE(review): not used in this part of the file; presumably predictions
    // that were shown to the user (e.g. for rating) — confirm.
    shown_completions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_completions: HashSet<EditPredictionId>,
    /// The user's data-collection choice, loaded on construction via
    /// `load_data_collection_choices`.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token sent as the bearer credential on prediction requests. It is
    /// refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Held so that the token-refresh subscription stays active.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Store that receives edit-prediction usage updates.
    user_store: Entity<UserStore>,
    /// One `LicenseDetectionWatcher` per registered worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
 227
/// Prediction state tracked for a single project.
struct ZetaProject {
    /// Recent buffer-change events (at most `MAX_EVENT_COUNT`). They are cloned
    /// into each request to build the events part of the prompt.
    events: VecDeque<Event>,
    /// Buffers tracked for this project, keyed by buffer entity id. An entry is
    /// removed when its buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
 232
 233impl Zeta {
 234    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 235        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 236    }
 237
 238    pub fn register(
 239        worktree: Option<Entity<Worktree>>,
 240        client: Arc<Client>,
 241        user_store: Entity<UserStore>,
 242        cx: &mut App,
 243    ) -> Entity<Self> {
 244        let this = Self::global(cx).unwrap_or_else(|| {
 245            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 246            cx.set_global(ZetaGlobal(entity.clone()));
 247            entity
 248        });
 249
 250        this.update(cx, move |this, cx| {
 251            if let Some(worktree) = worktree {
 252                let worktree_id = worktree.read(cx).id();
 253                this.license_detection_watchers
 254                    .entry(worktree_id)
 255                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 256            }
 257        });
 258
 259        this
 260    }
 261
 262    pub fn clear_history(&mut self) {
 263        for zeta_project in self.projects.values_mut() {
 264            zeta_project.events.clear();
 265        }
 266    }
 267
 268    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 269        self.user_store.read(cx).edit_prediction_usage()
 270    }
 271
 272    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 273        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 274
 275        let data_collection_choice = Self::load_data_collection_choices();
 276        let data_collection_choice = cx.new(|_| data_collection_choice);
 277
 278        Self {
 279            projects: HashMap::default(),
 280            client,
 281            shown_completions: VecDeque::new(),
 282            rated_completions: HashSet::default(),
 283            data_collection_choice,
 284            llm_token: LlmApiToken::default(),
 285            _llm_token_subscription: cx.subscribe(
 286                &refresh_llm_token_listener,
 287                |this, _listener, _event, cx| {
 288                    let client = this.client.clone();
 289                    let llm_token = this.llm_token.clone();
 290                    cx.spawn(async move |_this, _cx| {
 291                        llm_token.refresh(&client).await?;
 292                        anyhow::Ok(())
 293                    })
 294                    .detach_and_log_err(cx);
 295                },
 296            ),
 297            update_required: false,
 298            license_detection_watchers: HashMap::default(),
 299            user_store,
 300        }
 301    }
 302
 303    fn get_or_init_zeta_project(
 304        &mut self,
 305        project: &Entity<Project>,
 306        cx: &mut Context<Self>,
 307    ) -> &mut ZetaProject {
 308        let project_id = project.entity_id();
 309        match self.projects.entry(project_id) {
 310            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 311            hash_map::Entry::Vacant(entry) => {
 312                cx.observe_release(project, move |this, _, _cx| {
 313                    this.projects.remove(&project_id);
 314                })
 315                .detach();
 316                entry.insert(ZetaProject {
 317                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 318                    registered_buffers: HashMap::default(),
 319                })
 320            }
 321        }
 322    }
 323
 324    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 325        let events = &mut zeta_project.events;
 326
 327        if let Some(Event::BufferChange {
 328            new_snapshot: last_new_snapshot,
 329            timestamp: last_timestamp,
 330            ..
 331        }) = events.back_mut()
 332        {
 333            // Coalesce edits for the same buffer when they happen one after the other.
 334            let Event::BufferChange {
 335                old_snapshot,
 336                new_snapshot,
 337                timestamp,
 338            } = &event;
 339
 340            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 341                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 342                && old_snapshot.version == last_new_snapshot.version
 343            {
 344                *last_new_snapshot = new_snapshot.clone();
 345                *last_timestamp = *timestamp;
 346                return;
 347            }
 348        }
 349
 350        if events.len() >= MAX_EVENT_COUNT {
 351            // These are halved instead of popping to improve prompt caching.
 352            events.drain(..MAX_EVENT_COUNT / 2);
 353        }
 354
 355        events.push_back(event);
 356    }
 357
 358    pub fn register_buffer(
 359        &mut self,
 360        buffer: &Entity<Buffer>,
 361        project: &Entity<Project>,
 362        cx: &mut Context<Self>,
 363    ) {
 364        let zeta_project = self.get_or_init_zeta_project(project, cx);
 365        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 366    }
 367
 368    fn register_buffer_impl<'a>(
 369        zeta_project: &'a mut ZetaProject,
 370        buffer: &Entity<Buffer>,
 371        project: &Entity<Project>,
 372        cx: &mut Context<Self>,
 373    ) -> &'a mut RegisteredBuffer {
 374        let buffer_id = buffer.entity_id();
 375        match zeta_project.registered_buffers.entry(buffer_id) {
 376            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 377            hash_map::Entry::Vacant(entry) => {
 378                let snapshot = buffer.read(cx).snapshot();
 379                let project_entity_id = project.entity_id();
 380                entry.insert(RegisteredBuffer {
 381                    snapshot,
 382                    _subscriptions: [
 383                        cx.subscribe(buffer, {
 384                            let project = project.downgrade();
 385                            move |this, buffer, event, cx| {
 386                                if let language::BufferEvent::Edited = event
 387                                    && let Some(project) = project.upgrade()
 388                                {
 389                                    this.report_changes_for_buffer(&buffer, &project, cx);
 390                                }
 391                            }
 392                        }),
 393                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 394                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 395                            else {
 396                                return;
 397                            };
 398                            zeta_project.registered_buffers.remove(&buffer_id);
 399                        }),
 400                    ],
 401                })
 402            }
 403        }
 404    }
 405
    /// Shared implementation behind `request_completion` and `fake_completion`.
    ///
    /// Steps:
    /// 1. Calls `report_changes_for_buffer`, which returns the buffer snapshot
    ///    used for this request. It presumably also records the latest edits as
    ///    events — confirm at its definition.
    /// 2. Builds the request body with `gather_context` from the excerpt around
    ///    `cursor`, the project's recent events (rendered lazily via
    ///    `prompt_for_events`), and git info when `can_collect_data` is true and
    ///    the buffer has a file.
    /// 3. Sends the body through `perform_predict_edits`. This is generic so tests
    ///    can substitute a canned response. If that fails with
    ///    `ZedUpdateRequiredError`, sets `update_required` and shows an
    ///    "Update Zed" notification before returning the error.
    /// 4. Forwards any usage info to the user store and converts the response
    ///    into an [`EditPrediction`] via `process_completion_response`.
    ///
    /// Resolves to `Ok(None)` when the response yields no applicable edits.
    fn request_completion_impl<F, R>(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency measurement reported via `EditPrediction::latency`.
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
        // Strong handle, used below to flag `update_required`; the spawned task
        // also receives a weak handle (`this`) for the usage update.
        let zeta = cx.entity();
        let events = self.get_or_init_zeta_project(project, cx).events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Git info is only gathered when data collection is allowed and the
        // buffer is backed by a file.
        let git_info = if let (true, Some(file)) = (can_collect_data, snapshot.file()) {
            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // The events prompt is rendered lazily by `gather_context`.
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the prompt inputs; `body` is moved into the request.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        // Best effort: errors from `cx.update` (e.g. the app is
                        // gone) are ignored; the original error is still returned.
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            if let Some(usage) = usage {
                // Ignored if the `Zeta` entity has been released in the meantime.
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests (3 of 256 possible `u8` values,
            // ≈1.2%); recorded even when processing the response failed.
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 556
 557    #[cfg(any(test, feature = "test-support"))]
 558    pub fn fake_completion(
 559        &mut self,
 560        project: &Entity<Project>,
 561        buffer: &Entity<Buffer>,
 562        position: language::Anchor,
 563        response: PredictEditsResponse,
 564        cx: &mut Context<Self>,
 565    ) -> Task<Result<Option<EditPrediction>>> {
 566        use std::future::ready;
 567
 568        self.request_completion_impl(project, buffer, position, false, cx, |_params| {
 569            ready(Ok((response, None)))
 570        })
 571    }
 572
 573    pub fn request_completion(
 574        &mut self,
 575        project: &Entity<Project>,
 576        buffer: &Entity<Buffer>,
 577        position: language::Anchor,
 578        can_collect_data: bool,
 579        cx: &mut Context<Self>,
 580    ) -> Task<Result<Option<EditPrediction>>> {
 581        self.request_completion_impl(
 582            project,
 583            buffer,
 584            position,
 585            can_collect_data,
 586            cx,
 587            Self::perform_predict_edits,
 588        )
 589    }
 590
    /// POSTs `params.body` as JSON to the predict-edits endpoint and returns the
    /// parsed response, plus any usage info found in the response headers.
    ///
    /// The endpoint is taken from the `ZED_PREDICT_EDITS_URL` environment variable
    /// when set, otherwise the Zed LLM URL `/predict_edits/v2`.
    ///
    /// # Errors
    /// - `ZedUpdateRequiredError` when the server's minimum-required-version
    ///   header names a version newer than `app_version`. This is checked before
    ///   the status code, so it applies even to successful responses.
    /// - An error containing the status and body for a non-success response. The
    ///   exception is the first response carrying the expired-token header, which
    ///   triggers one token refresh and a retry.
    /// - Token, URL-building, serialization, transport and JSON-parse errors are
    ///   propagated.
    pub fn perform_predict_edits(
        params: PerformPredictEditsParams,
    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
        async move {
            let PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
                ..
            } = params;

            let http_client = client.http_client();
            let mut token = llm_token.acquire(&client).await?;
            // Allows at most one token refresh + retry per call.
            let mut did_retry = false;

            loop {
                let request_builder = http_client::Request::builder().method(Method::POST);
                // The environment variable overrides the default endpoint.
                let request_builder =
                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                        request_builder.uri(predict_edits_url)
                    } else {
                        request_builder.uri(
                            http_client
                                .build_zed_llm_url("/predict_edits/v2", &[])?
                                .as_ref(),
                        )
                    };
                let request = request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(serde_json::to_string(&body)?.into())?;

                let mut response = http_client.send(request).await?;

                // An unparseable or missing version header is ignored.
                if let Some(minimum_required_version) = response
                    .headers()
                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
                {
                    anyhow::ensure!(
                        app_version >= minimum_required_version,
                        ZedUpdateRequiredError {
                            minimum_version: minimum_required_version
                        }
                    );
                }

                if response.status().is_success() {
                    // Missing or malformed usage headers are not an error.
                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Ok((serde_json::from_str(&body)?, usage));
                } else if !did_retry
                    && response
                        .headers()
                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                        .is_some()
                {
                    // The server reported an expired token: refresh it and loop again.
                    did_retry = true;
                    token = llm_token.refresh(&client).await?;
                } else {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "error predicting edits.\nStatus: {:?}\nBody: {}",
                        response.status(),
                        body
                    );
                }
            }
        }
    }
 666
 667    fn accept_edit_prediction(
 668        &mut self,
 669        request_id: EditPredictionId,
 670        cx: &mut Context<Self>,
 671    ) -> Task<Result<()>> {
 672        let client = self.client.clone();
 673        let llm_token = self.llm_token.clone();
 674        let app_version = AppVersion::global(cx);
 675        cx.spawn(async move |this, cx| {
 676            let http_client = client.http_client();
 677            let mut response = llm_token_retry(&llm_token, &client, |token| {
 678                let request_builder = http_client::Request::builder().method(Method::POST);
 679                let request_builder =
 680                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 681                        request_builder.uri(accept_prediction_url)
 682                    } else {
 683                        request_builder.uri(
 684                            http_client
 685                                .build_zed_llm_url("/predict_edits/accept", &[])?
 686                                .as_ref(),
 687                        )
 688                    };
 689                Ok(request_builder
 690                    .header("Content-Type", "application/json")
 691                    .header("Authorization", format!("Bearer {}", token))
 692                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 693                    .body(
 694                        serde_json::to_string(&AcceptEditPredictionBody {
 695                            request_id: request_id.0,
 696                        })?
 697                        .into(),
 698                    )?)
 699            })
 700            .await?;
 701
 702            if let Some(minimum_required_version) = response
 703                .headers()
 704                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 705                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 706                && app_version < minimum_required_version
 707            {
 708                return Err(anyhow!(ZedUpdateRequiredError {
 709                    minimum_version: minimum_required_version
 710                }));
 711            }
 712
 713            if response.status().is_success() {
 714                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 715                    this.update(cx, |this, cx| {
 716                        this.user_store.update(cx, |user_store, cx| {
 717                            user_store.update_edit_prediction_usage(usage, cx);
 718                        });
 719                    })?;
 720                }
 721
 722                Ok(())
 723            } else {
 724                let mut body = String::new();
 725                response.body_mut().read_to_string(&mut body).await?;
 726                Err(anyhow!(
 727                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 728                    response.status(),
 729                    body
 730                ))
 731            }
 732        })
 733    }
 734
    /// Turns a predict-edits response into an [`EditPrediction`].
    ///
    /// The model output is parsed against the request-time `snapshot` with
    /// `parse_edits`, using `cx.background_spawn`. The resulting edits are then
    /// re-based onto the buffer's current snapshot with `interpolate`, since the
    /// user may have edited while the request was in flight, and an edit preview
    /// is computed.
    ///
    /// Resolves to `Ok(None)` when `interpolate` returns `None`: a user edit
    /// conflicts with the prediction, or no edits remain.
    ///
    /// # Errors
    /// Propagates `parse_edits` failures and errors from reading the buffer
    /// entity.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse the model output against the request-time snapshot.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-base onto the buffer as it is now; `None` drops the prediction.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                // Stamped after parsing and preview, so latency includes local
                // processing time.
                response_received_at: Instant::now(),
            }))
        })
    }
 796
 797    fn parse_edits(
 798        output_excerpt: Arc<str>,
 799        editable_range: Range<usize>,
 800        snapshot: &BufferSnapshot,
 801    ) -> Result<Vec<(Range<Anchor>, String)>> {
 802        let content = output_excerpt.replace(CURSOR_MARKER, "");
 803
 804        let start_markers = content
 805            .match_indices(EDITABLE_REGION_START_MARKER)
 806            .collect::<Vec<_>>();
 807        anyhow::ensure!(
 808            start_markers.len() == 1,
 809            "expected exactly one start marker, found {}",
 810            start_markers.len()
 811        );
 812
 813        let end_markers = content
 814            .match_indices(EDITABLE_REGION_END_MARKER)
 815            .collect::<Vec<_>>();
 816        anyhow::ensure!(
 817            end_markers.len() == 1,
 818            "expected exactly one end marker, found {}",
 819            end_markers.len()
 820        );
 821
 822        let sof_markers = content
 823            .match_indices(START_OF_FILE_MARKER)
 824            .collect::<Vec<_>>();
 825        anyhow::ensure!(
 826            sof_markers.len() <= 1,
 827            "expected at most one start-of-file marker, found {}",
 828            sof_markers.len()
 829        );
 830
 831        let codefence_start = start_markers[0].0;
 832        let content = &content[codefence_start..];
 833
 834        let newline_ix = content.find('\n').context("could not find newline")?;
 835        let content = &content[newline_ix + 1..];
 836
 837        let codefence_end = content
 838            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 839            .context("could not find end marker")?;
 840        let new_text = &content[..codefence_end];
 841
 842        let old_text = snapshot
 843            .text_for_range(editable_range.clone())
 844            .collect::<String>();
 845
 846        Ok(Self::compute_edits(
 847            old_text,
 848            new_text,
 849            editable_range.start,
 850            snapshot,
 851        ))
 852    }
 853
 854    pub fn compute_edits(
 855        old_text: String,
 856        new_text: &str,
 857        offset: usize,
 858        snapshot: &BufferSnapshot,
 859    ) -> Vec<(Range<Anchor>, String)> {
 860        text_diff(&old_text, new_text)
 861            .into_iter()
 862            .map(|(mut old_range, new_text)| {
 863                old_range.start += offset;
 864                old_range.end += offset;
 865
 866                let prefix_len = common_prefix(
 867                    snapshot.chars_for_range(old_range.clone()),
 868                    new_text.chars(),
 869                );
 870                old_range.start += prefix_len;
 871
 872                let suffix_len = common_prefix(
 873                    snapshot.reversed_chars_for_range(old_range.clone()),
 874                    new_text[prefix_len..].chars().rev(),
 875                );
 876                old_range.end = old_range.end.saturating_sub(suffix_len);
 877
 878                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 879                let range = if old_range.is_empty() {
 880                    let anchor = snapshot.anchor_after(old_range.start);
 881                    anchor..anchor
 882                } else {
 883                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 884                };
 885                (range, new_text)
 886            })
 887            .collect()
 888    }
 889
 890    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
 891        self.rated_completions.contains(&completion_id)
 892    }
 893
 894    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
 895        self.shown_completions.push_front(completion.clone());
 896        if self.shown_completions.len() > 50 {
 897            let completion = self.shown_completions.pop_back().unwrap();
 898            self.rated_completions.remove(&completion.id);
 899        }
 900        cx.notify();
 901    }
 902
 903    pub fn rate_completion(
 904        &mut self,
 905        completion: &EditPrediction,
 906        rating: EditPredictionRating,
 907        feedback: String,
 908        cx: &mut Context<Self>,
 909    ) {
 910        self.rated_completions.insert(completion.id);
 911        telemetry::event!(
 912            "Edit Prediction Rated",
 913            rating,
 914            input_events = completion.input_events,
 915            input_excerpt = completion.input_excerpt,
 916            input_outline = completion.input_outline,
 917            output_excerpt = completion.output_excerpt,
 918            feedback
 919        );
 920        self.client.telemetry().flush_events().detach();
 921        cx.notify();
 922    }
 923
 924    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
 925        self.shown_completions.iter()
 926    }
 927
 928    pub fn shown_completions_len(&self) -> usize {
 929        self.shown_completions.len()
 930    }
 931
 932    fn report_changes_for_buffer(
 933        &mut self,
 934        buffer: &Entity<Buffer>,
 935        project: &Entity<Project>,
 936        cx: &mut Context<Self>,
 937    ) -> BufferSnapshot {
 938        let zeta_project = self.get_or_init_zeta_project(project, cx);
 939        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 940
 941        let new_snapshot = buffer.read(cx).snapshot();
 942        if new_snapshot.version != registered_buffer.snapshot.version {
 943            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 944            Self::push_event(
 945                zeta_project,
 946                Event::BufferChange {
 947                    old_snapshot,
 948                    new_snapshot: new_snapshot.clone(),
 949                    timestamp: Instant::now(),
 950                },
 951            );
 952        }
 953
 954        new_snapshot
 955    }
 956
 957    fn load_data_collection_choices() -> DataCollectionChoice {
 958        let choice = KEY_VALUE_STORE
 959            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 960            .log_err()
 961            .flatten();
 962
 963        match choice.as_deref() {
 964            Some("true") => DataCollectionChoice::Enabled,
 965            Some("false") => DataCollectionChoice::Disabled,
 966            Some(_) => {
 967                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
 968                DataCollectionChoice::NotAnswered
 969            }
 970            None => DataCollectionChoice::NotAnswered,
 971        }
 972    }
 973}
 974
 975pub struct PerformPredictEditsParams {
 976    pub client: Arc<Client>,
 977    pub llm_token: LlmApiToken,
 978    pub app_version: SemanticVersion,
 979    pub body: PredictEditsBody,
 980}
 981
 982#[derive(Error, Debug)]
 983#[error(
 984    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 985)]
 986pub struct ZedUpdateRequiredError {
 987    minimum_version: SemanticVersion,
 988}
 989
 990fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 991    a.zip(b)
 992        .take_while(|(a, b)| a == b)
 993        .map(|(a, _)| a.len_utf8())
 994        .sum()
 995}
 996
 997fn git_info_for_file(
 998    project: &Entity<Project>,
 999    project_path: &ProjectPath,
1000    cx: &App,
1001) -> Option<PredictEditsGitInfo> {
1002    let git_store = project.read(cx).git_store().read(cx);
1003    if let Some((repository, _repo_path)) =
1004        git_store.repository_and_path_for_project_path(project_path, cx)
1005    {
1006        let repository = repository.read(cx);
1007        let head_sha = repository
1008            .head_commit
1009            .as_ref()
1010            .map(|head_commit| head_commit.sha.to_string());
1011        let remote_origin_url = repository.remote_origin_url.clone();
1012        let remote_upstream_url = repository.remote_upstream_url.clone();
1013        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1014            return None;
1015        }
1016        Some(PredictEditsGitInfo {
1017            head_sha,
1018            remote_origin_url,
1019            remote_upstream_url,
1020        })
1021    } else {
1022        None
1023    }
1024}
1025
1026pub struct GatherContextOutput {
1027    pub body: PredictEditsBody,
1028    pub editable_range: Range<usize>,
1029}
1030
1031pub fn gather_context(
1032    project: &Entity<Project>,
1033    full_path_str: String,
1034    snapshot: &BufferSnapshot,
1035    cursor_point: language::Point,
1036    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1037    can_collect_data: bool,
1038    git_info: Option<PredictEditsGitInfo>,
1039    cx: &App,
1040) -> Task<Result<GatherContextOutput>> {
1041    let local_lsp_store = project.read(cx).lsp_store().read(cx).as_local();
1042    let diagnostic_groups: Vec<(String, serde_json::Value)> =
1043        if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
1044            snapshot
1045                .diagnostic_groups(None)
1046                .into_iter()
1047                .filter_map(|(language_server_id, diagnostic_group)| {
1048                    let language_server =
1049                        local_lsp_store.running_language_server_for_id(language_server_id)?;
1050                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1051                    let language_server_name = language_server.name().to_string();
1052                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
1053                    Some((language_server_name, serialized))
1054                })
1055                .collect::<Vec<_>>()
1056        } else {
1057            Vec::new()
1058        };
1059
1060    cx.background_spawn({
1061        let snapshot = snapshot.clone();
1062        async move {
1063            let diagnostic_groups = if diagnostic_groups.is_empty()
1064                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1065            {
1066                None
1067            } else {
1068                Some(diagnostic_groups)
1069            };
1070
1071            let input_excerpt = excerpt_for_cursor_position(
1072                cursor_point,
1073                &full_path_str,
1074                &snapshot,
1075                MAX_REWRITE_TOKENS,
1076                MAX_CONTEXT_TOKENS,
1077            );
1078            let input_events = make_events_prompt();
1079            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1080
1081            let body = PredictEditsBody {
1082                input_events,
1083                input_excerpt: input_excerpt.prompt,
1084                can_collect_data,
1085                diagnostic_groups,
1086                git_info,
1087                outline: None,
1088                speculated_output: None,
1089            };
1090
1091            Ok(GatherContextOutput {
1092                body,
1093                editable_range,
1094            })
1095        }
1096    })
1097}
1098
1099fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1100    let mut result = String::new();
1101    for event in events.iter().rev() {
1102        let event_string = event.to_prompt();
1103        let event_tokens = tokens_for_bytes(event_string.len());
1104        if event_tokens > remaining_tokens {
1105            break;
1106        }
1107
1108        if !result.is_empty() {
1109            result.insert_str(0, "\n\n");
1110        }
1111        result.insert_str(0, &event_string);
1112        remaining_tokens -= event_tokens;
1113    }
1114    result
1115}
1116
1117struct RegisteredBuffer {
1118    snapshot: BufferSnapshot,
1119    _subscriptions: [gpui::Subscription; 2],
1120}
1121
1122#[derive(Clone)]
1123pub enum Event {
1124    BufferChange {
1125        old_snapshot: BufferSnapshot,
1126        new_snapshot: BufferSnapshot,
1127        timestamp: Instant,
1128    },
1129}
1130
1131impl Event {
1132    fn to_prompt(&self) -> String {
1133        match self {
1134            Event::BufferChange {
1135                old_snapshot,
1136                new_snapshot,
1137                ..
1138            } => {
1139                let mut prompt = String::new();
1140
1141                let old_path = old_snapshot
1142                    .file()
1143                    .map(|f| f.path().as_ref())
1144                    .unwrap_or(Path::new("untitled"));
1145                let new_path = new_snapshot
1146                    .file()
1147                    .map(|f| f.path().as_ref())
1148                    .unwrap_or(Path::new("untitled"));
1149                if old_path != new_path {
1150                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1151                }
1152
1153                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1154                if !diff.is_empty() {
1155                    write!(
1156                        prompt,
1157                        "User edited {:?}:\n```diff\n{}\n```",
1158                        new_path, diff
1159                    )
1160                    .unwrap();
1161                }
1162
1163                prompt
1164            }
1165        }
1166    }
1167}
1168
1169#[derive(Debug, Clone)]
1170struct CurrentEditPrediction {
1171    buffer_id: EntityId,
1172    completion: EditPrediction,
1173}
1174
1175impl CurrentEditPrediction {
1176    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1177        if self.buffer_id != old_completion.buffer_id {
1178            return true;
1179        }
1180
1181        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1182            return true;
1183        };
1184        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1185            return false;
1186        };
1187
1188        if old_edits.len() == 1 && new_edits.len() == 1 {
1189            let (old_range, old_text) = &old_edits[0];
1190            let (new_range, new_text) = &new_edits[0];
1191            new_range == old_range && new_text.starts_with(old_text)
1192        } else {
1193            true
1194        }
1195    }
1196}
1197
1198struct PendingCompletion {
1199    id: usize,
1200    _task: Task<()>,
1201}
1202
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// True only for an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True once the user has chosen either way.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// The choice a toggle leads to: opted in becomes opted out, and anything else
    /// (opted out or unanswered) becomes opted in.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::NotAnswered | Self::Disabled => Self::Enabled,
            Self::Enabled => Self::Disabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` means opted in, `false` opted out.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1242
1243pub struct ProviderDataCollection {
1244    /// When set to None, data collection is not possible in the provider buffer
1245    choice: Option<Entity<DataCollectionChoice>>,
1246    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1247}
1248
1249impl ProviderDataCollection {
1250    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1251        let choice_and_watcher = buffer.and_then(|buffer| {
1252            let file = buffer.read(cx).file()?;
1253
1254            if !file.is_local() || file.is_private() {
1255                return None;
1256            }
1257
1258            let zeta = zeta.read(cx);
1259            let choice = zeta.data_collection_choice.clone();
1260
1261            let license_detection_watcher = zeta
1262                .license_detection_watchers
1263                .get(&file.worktree_id(cx))
1264                .cloned()?;
1265
1266            Some((choice, license_detection_watcher))
1267        });
1268
1269        if let Some((choice, watcher)) = choice_and_watcher {
1270            ProviderDataCollection {
1271                choice: Some(choice),
1272                license_detection_watcher: Some(watcher),
1273            }
1274        } else {
1275            ProviderDataCollection {
1276                choice: None,
1277                license_detection_watcher: None,
1278            }
1279        }
1280    }
1281
1282    pub fn can_collect_data(&self, cx: &App) -> bool {
1283        self.is_data_collection_enabled(cx) && self.is_project_open_source()
1284    }
1285
1286    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1287        self.choice
1288            .as_ref()
1289            .is_some_and(|choice| choice.read(cx).is_enabled())
1290    }
1291
1292    fn is_project_open_source(&self) -> bool {
1293        self.license_detection_watcher
1294            .as_ref()
1295            .is_some_and(|watcher| watcher.is_project_open_source())
1296    }
1297
1298    pub fn toggle(&mut self, cx: &mut App) {
1299        if let Some(choice) = self.choice.as_mut() {
1300            let new_choice = choice.update(cx, |choice, _cx| {
1301                let new_choice = choice.toggle();
1302                *choice = new_choice;
1303                new_choice
1304            });
1305
1306            db::write_and_log(cx, move || {
1307                KEY_VALUE_STORE.write_kvp(
1308                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1309                    new_choice.is_enabled().to_string(),
1310                )
1311            });
1312        }
1313    }
1314}
1315
1316async fn llm_token_retry(
1317    llm_token: &LlmApiToken,
1318    client: &Arc<Client>,
1319    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1320) -> Result<Response<AsyncBody>> {
1321    let mut did_retry = false;
1322    let http_client = client.http_client();
1323    let mut token = llm_token.acquire(client).await?;
1324    loop {
1325        let request = build_request(token.clone())?;
1326        let response = http_client.send(request).await?;
1327
1328        if !did_retry
1329            && !response.status().is_success()
1330            && response
1331                .headers()
1332                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1333                .is_some()
1334        {
1335            did_retry = true;
1336            token = llm_token.refresh(client).await?;
1337            continue;
1338        }
1339
1340        return Ok(response);
1341    }
1342}
1343
1344pub struct ZetaEditPredictionProvider {
1345    zeta: Entity<Zeta>,
1346    pending_completions: ArrayVec<PendingCompletion, 2>,
1347    next_pending_completion_id: usize,
1348    current_completion: Option<CurrentEditPrediction>,
1349    /// None if this is entirely disabled for this provider
1350    provider_data_collection: ProviderDataCollection,
1351    last_request_timestamp: Instant,
1352}
1353
1354impl ZetaEditPredictionProvider {
1355    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1356
1357    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1358        Self {
1359            zeta,
1360            pending_completions: ArrayVec::new(),
1361            next_pending_completion_id: 0,
1362            current_completion: None,
1363            provider_data_collection,
1364            last_request_timestamp: Instant::now(),
1365        }
1366    }
1367}
1368
/// Connects Zeta to the editor's edit-prediction provider interface.
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Identifier of this provider.
    fn name() -> &'static str {
        "zed-predict"
    }

    /// Human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled for this provider's buffer, along with
    /// whether the project was detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Flips and persists the data-collection choice via `ProviderDataCollection::toggle`.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    /// Usage information as reported by Zeta.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }

    /// Predictions are offered regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// True while at least one prediction request is in flight.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Starts a new prediction request for `position` in `buffer`, unless requests are
    /// currently blocked or the current prediction still applies.
    ///
    /// The request first waits out whatever remains of `THROTTLE_TIMEOUT` since the previous
    /// request was sent. At most two requests are tracked; a third replaces the newer pending
    /// one. A finished request's prediction becomes current unless
    /// `should_replace_completion` decides to keep the existing one. `_debounce` is not used.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        // No requests while Zeta has flagged that a newer Zed version is required.
        if self.zeta.read(cx).update_required {
            return;
        }
        // Requests need a project.
        let Some(project) = project else {
            return;
        };

        // No requests for accounts that are too young or have overdue invoices.
        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the current prediction as long as it can still be interpolated onto the
        // buffer's latest snapshot.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait until THROTTLE_TIMEOUT has elapsed since the previous send.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // Record the send time and start the request on Zeta.
            let completion_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(&project, &buffer, position, can_collect_data, cx)
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged; both an error and an empty result end up in the `else` arm.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(cx, |this, cx| {
                    // If this was the oldest pending request, only its entry is removed;
                    // otherwise a newer request finished first and the older one is
                    // discarded as well.
                    // NOTE(review): indexes [0] unchecked; relies on this task's own entry
                    // still being in the list (clearing/replacing drops the handle, which
                    // presumably cancels the task) — confirm.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(cx, |this, cx| {
                // Same pending-list bookkeeping as in the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Adopt the new prediction (and record it as shown) unless the existing one
                // should be kept.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    /// Cycling between alternative predictions is not supported; this is a no-op.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Reports the current prediction (if any) to Zeta as accepted and drops all pending
    /// requests. The current prediction itself is left in place.
    fn accept(&mut self, cx: &mut Context<Self>) {
        let completion_id = self
            .current_completion
            .as_ref()
            .map(|completion| completion.completion.id);
        if let Some(completion_id) = completion_id {
            self.zeta
                .update(cx, |zeta, cx| {
                    zeta.accept_edit_prediction(completion_id, cx)
                })
                .detach();
        }
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the part of the current prediction to show near `cursor_position`.
    ///
    /// The current prediction is dropped (and `None` returned) when it belongs to another
    /// buffer or can no longer be interpolated onto the buffer. Otherwise the edit whose
    /// start or end row is nearest the cursor row is chosen and extended in both directions
    /// over neighbouring edits within one row of that chosen edit; only that run is returned.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        // The buffer diverged from the prediction; discard it.
        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        // Edit whose start or end row is closest to the cursor row; `?` covers empty `edits`.
        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Extend backwards while earlier edits end at most one row above the closest edit's
        // start (measured against the closest edit, not the adjacent one). The unsigned
        // subtraction assumes edits are ordered and non-overlapping.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Extend forwards while later edits start at most one row below the closest edit's
        // end, under the same ordering assumption.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1640
/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    // A deliberately low bytes-per-token figure: it inflates the token estimate, so input
    // limits err on the side of being underestimated rather than exceeded.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1647
// Tests for edit-prediction interpolation, cleanup of model output into small
// edits, and predictions that extend to the end of the buffer. Network access is
// replaced by `FakeHttpClient` handlers that serve canned LLM-token and
// `/predict_edits/v2` responses.
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;
    use util::path;

    use super::*;

    // Verifies that `EditPrediction::interpolate` rebases a prediction onto a
    // buffer the user has edited since the prediction was made:
    // - edits shift when text before them changes;
    // - typing part of a predicted insertion shrinks what is left to insert;
    // - fully applied edits drop out;
    // - a divergent user edit invalidates the whole prediction (`None`).
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Predicted edits: replace "rem" (2..5) with "REM", and delete "um" (9..11).
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        // Only `edits`, `edit_preview`, and `snapshot` matter for interpolation in
        // this test; the remaining fields are filled with placeholder values.
        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Buffer unchanged: the prediction comes back as originally made.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User deletes "rem": the first edit becomes a pure insertion and the
            // second edit shifts left by the three deleted bytes.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original text and therefore the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User replaces "rem" with "R", a prefix of the predicted "REM":
            // only "EM" remains to be inserted.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // User types "E": only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User types "M": the first edit is fully applied and drops out.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // User deletes the "M" again: the pending insertion comes back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User deletes "um" (8..10) themselves, performing the predicted
            // deletion; only the insertion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // User deletes " i" (4..6), an edit that diverges from the prediction:
            // interpolation gives up and returns `None`.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    // Verifies that a model response rewriting the whole editable region is
    // reduced to the small insertions that differ from the buffer, not to a
    // replacement of the entire region.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
            Project::init_settings(cx);
        });

        // The response renames `word` to `word_1` in two places: expect two "_1"
        // insertions.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        // The response extends the string literal and adds a trailing ";": expect
        // one insertion before the closing quote and one after it.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    // Verifies that a prediction made with the cursor on the empty last line of
    // "lorem\n" can append text there. Applying the predicted edits must yield
    // "lorem\nipsum".
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
            Project::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake backend with three routes:
        // - POST /client/llm_tokens returns a fixed LLM token;
        // - POST /predict_edits/v2 returns `completion_response`;
        // - every other route returns 404.
        let http_client = FakeHttpClient::create(move |req| async move {
            match (req.method(), req.uri().path()) {
                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&CreateLlmTokenResponse {
                            token: LlmToken("the-llm-token".to_string()),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                                .unwrap(),
                            output_excerpt: completion_response.to_string(),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                _ => Ok(http_client::Response::builder()
                    .status(404)
                    .body("Not Found".into())
                    .unwrap()),
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Point(1, 0) is the start of the empty line after "lorem\n", i.e. the end
        // of the buffer.
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));

        let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(&project, &buffer, cursor, false, cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Requests a prediction for a buffer containing `buffer_content` and returns
    /// the predicted edits as point ranges in that buffer.
    ///
    /// The prediction request is served by a fake backend:
    /// - POST /client/llm_tokens returns a fixed token;
    /// - POST /predict_edits/v2 returns `completion_response` as the output
    ///   excerpt, with a fresh request id;
    /// - every other route returns 404.
    ///
    /// The cursor is placed at the start of the buffer's second line
    /// (`Point::new(1, 0)`).
    ///
    /// Panics if the request fails or yields no prediction.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |req| {
            // Clone for each invocation so the async block can own its copy
            // while the handler closure keeps the original for later requests.
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Snapshot taken before the request so anchors resolve against the
        // original buffer text.
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));

        let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(&project, &buffer, cursor, false, cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts byte-offset edits into anchor-range edits on `buffer`.
    ///
    /// Each range start uses `anchor_after` and each end uses `anchor_before`.
    /// Presumably this is so that text inserted exactly at either boundary lands
    /// outside the anchored range; confirm against `Anchor` bias semantics.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-range edits back to byte offsets in the current contents
    /// of `buffer`. This is the inverse of `to_completion_edits`, used for
    /// assertions.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    // Runs once when the test binary loads (via `ctor`) and initializes test
    // logging through `zlog::init_test`.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}