zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
use arrayvec::ArrayVec;
pub(crate) use completion_diff_element::*;
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction::DataCollectionState;
use editor::Editor;
pub use init::*;
use license_detection::LicenseDetectionWatcher;
use project::git_store::Repository;
pub use rate_completion_modal::*;

use anyhow::{Context as _, Result, anyhow};
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
    PredictEditsResponse, ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt;
use gpui::{
    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
    SharedString, Subscription, Task, WeakEntity, actions,
};
use http_client::{AsyncBody, HttpClient, Method, Request, Response};
use input_excerpt::excerpt_for_cursor_position;
use language::{
    Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff,
};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use multi_buffer::MultiBufferPoint;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use settings::WorktreeId;
use std::str::FromStr;
use std::{
    cmp,
    fmt::Write,
    future::Future,
    mem,
    ops::Range,
    path::Path,
    rc::Rc,
    sync::Arc,
    time::{Duration, Instant},
};
use telemetry_events::EditPredictionRating;
use thiserror::Error;
use util::{ResultExt, maybe};
use uuid::Uuid;
use workspace::Workspace;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use worktree::Worktree;
  60
/// Marker token for the user's cursor position.
// NOTE(review): presumably spliced into the prompt excerpt; the code that inserts it is
// outside this part of the file — confirm.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker token for the start of a file.
// NOTE(review): usage is outside this part of the file — confirm where it is emitted.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker token that opens the editable region of an excerpt.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker token that closes the editable region of an excerpt.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Maximum gap between two consecutive changes to the same buffer for them to be
/// coalesced into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key holding the user's data collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Token budgets for building a prediction request.
// NOTE(review): only `MAX_EVENT_TOKENS` is used in the visible code (as the budget for the
// events prompt); the context and rewrite budgets are presumably used when building the
// input excerpt — confirm.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Minimum time between recent file entries.
const MIN_TIME_BETWEEN_RECENT_FILES: Duration = Duration::from_millis(100);

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of JSON bytes for diagnostics in additional context.
const MAX_DIAGNOSTICS_BYTES: usize = 4096;

/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;

/// Interval between polls tracking time editing files.
const ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(10);

/// Interval between polls of whether data collection is enabled, when it is disabled.
const DISABLED_ACTIVITY_POLL_INTERVAL: Duration = Duration::from_secs(60 * 5);
  95
// Declares the actions exposed by this module under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
 103
/// Identifier of an edit prediction; a newtype around a [`Uuid`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
 106
 107impl From<EditPredictionId> for gpui::ElementId {
 108    fn from(value: EditPredictionId) -> Self {
 109        gpui::ElementId::Uuid(value.0)
 110    }
 111}
 112
 113impl std::fmt::Display for EditPredictionId {
 114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 115        write!(f, "{}", self.0)
 116    }
 117}
 118
/// Marker type whose [`Dismissable`] impl records whether the edit prediction upsell
/// has been dismissed.
struct ZedPredictUpsell;
 120
 121impl Dismissable for ZedPredictUpsell {
 122    const KEY: &'static str = "dismissed-edit-predict-upsell";
 123
 124    fn dismissed() -> bool {
 125        // To make this backwards compatible with older versions of Zed, we
 126        // check if the user has seen the previous Edit Prediction Onboarding
 127        // before, by checking the data collection choice which was written to
 128        // the database once the user clicked on "Accept and Enable"
 129        if KEY_VALUE_STORE
 130            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 131            .log_err()
 132            .is_some_and(|s| s.is_some())
 133        {
 134            return true;
 135        }
 136
 137        KEY_VALUE_STORE
 138            .read_kvp(Self::KEY)
 139            .log_err()
 140            .is_some_and(|s| s.is_some())
 141    }
 142}
 143
 144pub fn should_show_upsell_modal() -> bool {
 145    !ZedPredictUpsell::dismissed()
 146}
 147
/// App-global wrapper holding the shared [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

// Marker impl that lets `ZetaGlobal` be stored as a gpui global.
impl Global for ZetaGlobal {}
 152
/// A single edit prediction: the proposed edits, the buffer snapshot they refer to, and
/// the request inputs/outputs and timing that produced them.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    path: Arc<Path>,
    excerpt_range: Range<usize>,
    cursor_offset: usize,
    /// Proposed replacements as `(range, new text)` pairs; the anchor ranges are resolved
    /// against `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` refers to.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
    // NOTE(review): the three strings below presumably mirror the request's events and
    // excerpt and the response's output excerpt — confirm where they are populated.
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// Start of the interval reported by `latency`.
    buffer_snapshotted_at: Instant,
    /// End of the interval reported by `latency`.
    response_received_at: Instant,
}
 168
 169impl EditPrediction {
 170    fn latency(&self) -> Duration {
 171        self.response_received_at
 172            .duration_since(self.buffer_snapshotted_at)
 173    }
 174
 175    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 176        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 177    }
 178}
 179
 180fn interpolate(
 181    old_snapshot: &BufferSnapshot,
 182    new_snapshot: &BufferSnapshot,
 183    current_edits: Arc<[(Range<Anchor>, String)]>,
 184) -> Option<Vec<(Range<Anchor>, String)>> {
 185    let mut edits = Vec::new();
 186
 187    let mut model_edits = current_edits.iter().peekable();
 188    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 189        while let Some((model_old_range, _)) = model_edits.peek() {
 190            let model_old_range = model_old_range.to_offset(old_snapshot);
 191            if model_old_range.end < user_edit.old.start {
 192                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 193                edits.push((model_old_range.clone(), model_new_text.clone()));
 194            } else {
 195                break;
 196            }
 197        }
 198
 199        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 200            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 201            if user_edit.old == model_old_offset_range {
 202                let user_new_text = new_snapshot
 203                    .text_for_range(user_edit.new.clone())
 204                    .collect::<String>();
 205
 206                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 207                    if !model_suffix.is_empty() {
 208                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 209                        edits.push((anchor..anchor, model_suffix.to_string()));
 210                    }
 211
 212                    model_edits.next();
 213                    continue;
 214                }
 215            }
 216        }
 217
 218        return None;
 219    }
 220
 221    edits.extend(model_edits.cloned());
 222
 223    if edits.is_empty() { None } else { Some(edits) }
 224}
 225
 226impl std::fmt::Debug for EditPrediction {
 227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 228        f.debug_struct("EditPrediction")
 229            .field("id", &self.id)
 230            .field("path", &self.path)
 231            .field("edits", &self.edits)
 232            .finish_non_exhaustive()
 233    }
 234}
 235
 236pub struct Zeta {
 237<<<<<<< HEAD
 238    workspace: WeakEntity<Workspace>,
 239=======
 240>>>>>>> main
 241    client: Arc<Client>,
 242    events: VecDeque<Event>,
 243    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 244    shown_completions: VecDeque<EditPrediction>,
 245    rated_completions: HashSet<EditPredictionId>,
 246    data_collection_choice: Entity<DataCollectionChoice>,
 247    llm_token: LlmApiToken,
 248    _llm_token_subscription: Subscription,
 249    /// Whether an update to a newer version of Zed is required to continue using Zeta.
 250    update_required: bool,
 251    user_store: Entity<UserStore>,
 252    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 253    recent_editors: VecDeque<RecentEditor>,
 254    last_activity_state: Option<ActivityState>,
 255    _activity_poll_task: Option<Task<Result<()>>>,
 256}
 257
/// Activity bookkeeping for one recently used editor.
// NOTE(review): the code that updates these fields is outside this part of the file; the
// field meanings are taken from their names — confirm against the activity polling code.
struct RecentEditor {
    /// Weak handle to the editor being tracked.
    editor: WeakEntity<Editor>,
    last_active_at: Instant,
    activation_count: u32,
    cumulative_time_editing: Duration,
    cumulative_time_navigating: Duration,
}
 265
/// Snapshot of editor state: scroll position, cursor position, and (optionally) a buffer
/// version.
// NOTE(review): presumably compared between activity polls to tell editing from
// navigating; `singleton_version` looks like the version of a singleton buffer — confirm.
#[derive(Debug)]
struct ActivityState {
    scroll_position: gpui::Point<f32>,
    cursor_point: MultiBufferPoint,
    singleton_version: Option<clock::Global>,
}
 272
 273impl Zeta {
 274    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 275        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 276    }
 277
 278    pub fn register(
 279<<<<<<< HEAD
 280        workspace: Option<Entity<Workspace>>,
 281=======
 282>>>>>>> main
 283        worktree: Option<Entity<Worktree>>,
 284        client: Arc<Client>,
 285        user_store: Entity<UserStore>,
 286        cx: &mut App,
 287    ) -> Entity<Self> {
 288        let this = Self::global(cx).unwrap_or_else(|| {
 289            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 290            cx.set_global(ZetaGlobal(entity.clone()));
 291            entity
 292        });
 293
 294        this.update(cx, move |this, cx| {
 295            if let Some(worktree) = worktree {
 296                let worktree_id = worktree.read(cx).id();
 297                this.license_detection_watchers
 298                    .entry(worktree_id)
 299                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 300            }
 301        });
 302
 303        this
 304    }
 305
    /// Discards all recorded events.
    pub fn clear_history(&mut self) {
        self.events.clear();
    }
 309
 310    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 311        self.user_store.read(cx).edit_prediction_usage()
 312    }
 313
 314<<<<<<< HEAD
 315    fn new(
 316        workspace: Option<Entity<Workspace>>,
 317        client: Arc<Client>,
 318        user_store: Entity<UserStore>,
 319        cx: &mut Context<Self>,
 320    ) -> Self {
 321=======
 322    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 323>>>>>>> main
 324        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 325
 326        let data_collection_choice = Self::load_data_collection_choices();
 327        let data_collection_choice = cx.new(|_| data_collection_choice);
 328
 329        let mut activity_poll_task = None;
 330
 331        if let Some(workspace) = &workspace {
 332            let project = workspace.read(cx).project().clone();
 333            cx.subscribe(&project, |this, _project, event, cx| match event {
 334                project::Event::ActiveEntryChanged(entry_id) => {
 335                    this.handle_active_project_entry_changed(cx)
 336                }
 337                _ => {}
 338            })
 339            .detach();
 340
 341            // TODO: ideally this would attend to window focus when tracking time, and pause the
 342            // loop for efficiency when not focused.
 343            activity_poll_task = Some(cx.spawn(async move |this, cx| {
 344                let mut instant_before_delay = None;
 345                loop {
 346                    let data_collection_is_enabled = this.read_with(cx, |this, cx| {
 347                        this.data_collection_choice.read(cx).is_enabled()
 348                    })?;
 349                    let interval = if data_collection_is_enabled {
 350                        ACTIVITY_POLL_INTERVAL
 351                    } else {
 352                        instant_before_delay = None;
 353                        DISABLED_ACTIVITY_POLL_INTERVAL
 354                    };
 355                    cx.background_executor().timer(interval).await;
 356                    this.update(cx, |this, cx| {
 357                        let now = Instant::now();
 358                        this.handle_activity_poll(instant_before_delay, now, cx);
 359                        instant_before_delay = Some(now);
 360                    })?
 361                }
 362            }));
 363        }
 364
 365        Self {
 366<<<<<<< HEAD
 367            workspace: workspace.map_or_else(
 368                || WeakEntity::new_invalid(),
 369                |workspace| workspace.downgrade(),
 370            ),
 371=======
 372>>>>>>> main
 373            client,
 374            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 375            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
 376            rated_completions: HashSet::default(),
 377            registered_buffers: HashMap::default(),
 378            data_collection_choice,
 379            llm_token: LlmApiToken::default(),
 380            _llm_token_subscription: cx.subscribe(
 381                &refresh_llm_token_listener,
 382                |this, _listener, _event, cx| {
 383                    let client = this.client.clone();
 384                    let llm_token = this.llm_token.clone();
 385                    cx.spawn(async move |_this, _cx| {
 386                        llm_token.refresh(&client).await?;
 387                        anyhow::Ok(())
 388                    })
 389                    .detach_and_log_err(cx);
 390                },
 391            ),
 392            update_required: false,
 393            license_detection_watchers: HashMap::default(),
 394            user_store,
 395            recent_editors: VecDeque::new(),
 396            last_activity_state: None,
 397            _activity_poll_task: activity_poll_task,
 398        }
 399    }
 400
 401    fn push_event(&mut self, event: Event) {
 402        if let Some(Event::BufferChange {
 403            new_snapshot: last_new_snapshot,
 404            timestamp: last_timestamp,
 405            ..
 406        }) = self.events.back_mut()
 407        {
 408            // Coalesce edits for the same buffer when they happen one after the other.
 409            let Event::BufferChange {
 410                old_snapshot,
 411                new_snapshot,
 412                timestamp,
 413            } = &event;
 414
 415            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 416                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 417                && old_snapshot.version == last_new_snapshot.version
 418            {
 419                *last_new_snapshot = new_snapshot.clone();
 420                *last_timestamp = *timestamp;
 421                return;
 422            }
 423        }
 424
 425        if self.events.len() >= MAX_EVENT_COUNT {
 426            // These are halved instead of popping to improve prompt caching.
 427            self.events.drain(..MAX_EVENT_COUNT / 2);
 428        }
 429
 430        self.events.push_back(event);
 431    }
 432
 433    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 434        let buffer_id = buffer.entity_id();
 435        let weak_buffer = buffer.downgrade();
 436
 437        if let std::collections::hash_map::Entry::Vacant(entry) =
 438            self.registered_buffers.entry(buffer_id)
 439        {
 440            let snapshot = buffer.read(cx).snapshot();
 441
 442            entry.insert(RegisteredBuffer {
 443                snapshot,
 444                _subscriptions: [
 445                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 446                        this.handle_buffer_event(buffer, event, cx);
 447                    }),
 448                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 449                        this.registered_buffers.remove(&weak_buffer.entity_id());
 450                    }),
 451                ],
 452            });
 453        };
 454    }
 455
 456    fn handle_buffer_event(
 457        &mut self,
 458        buffer: Entity<Buffer>,
 459        event: &language::BufferEvent,
 460        cx: &mut Context<Self>,
 461    ) {
 462        if let language::BufferEvent::Edited = event {
 463            self.report_changes_for_buffer(&buffer, cx);
 464        }
 465    }
 466
 467    fn request_completion_impl<F, R>(
 468        &mut self,
 469<<<<<<< HEAD
 470        workspace: Option<Entity<Workspace>>,
 471        project: Option<Entity<Project>>,
 472=======
 473        project: Option<&Entity<Project>>,
 474>>>>>>> main
 475        buffer: &Entity<Buffer>,
 476        cursor: language::Anchor,
 477        can_collect_data: CanCollectData,
 478        cx: &mut Context<Self>,
 479        perform_predict_edits: F,
 480    ) -> Task<Result<Option<EditPrediction>>>
 481    where
 482        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
 483        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
 484            + Send
 485            + 'static,
 486    {
 487        let buffer = buffer.clone();
 488        let buffer_snapshotted_at = Instant::now();
 489        let snapshot = self.report_changes_for_buffer(&buffer, cx);
 490        let zeta = cx.entity();
 491        let events = self.events.clone();
 492        let client = self.client.clone();
 493        let llm_token = self.llm_token.clone();
 494        let app_version = AppVersion::global(cx);
 495
 496        let full_path: Arc<Path> = snapshot
 497            .file()
 498            .map(|f| Arc::from(f.full_path(cx).as_path()))
 499            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 500        let full_path_str = full_path.to_string_lossy().to_string();
 501        let cursor_point = cursor.to_point(&snapshot);
 502        let cursor_offset = cursor_point.to_offset(&snapshot);
 503        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
 504        let gather_task = cx.background_spawn(gather_context(
 505            full_path_str,
 506            snapshot.clone(),
 507            cursor_point,
 508            make_events_prompt,
 509            can_collect_data,
 510        ));
 511
 512        cx.spawn(async move |this, cx| {
 513            let GatherContextOutput {
 514                body,
 515                editable_range,
 516            } = gather_task.await?;
 517            let done_gathering_context_at = Instant::now();
 518
 519            let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
 520                && let Some(file) = snapshot.file()
 521                && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
 522            {
 523                // This is async to reduce latency of the edit predictions request. The downside is
 524                // that it will see a slightly later state than was used when gathering context.
 525                let snapshot = snapshot.clone();
 526                let this = this.clone();
 527                Some(cx.spawn(async move |cx| {
 528                    if let Ok(Some(task)) = this.update(cx, |this, cx| {
 529                        this.gather_additional_context(
 530                            cursor_point,
 531                            cursor_offset,
 532                            snapshot,
 533                            &buffer_snapshotted_at,
 534                            project_path,
 535                            project.as_ref(),
 536                            cx,
 537                        )
 538                    }) {
 539                        Some(task.await)
 540                    } else {
 541                        None
 542                    }
 543                }))
 544            } else {
 545                None
 546            };
 547
 548            log::debug!(
 549                "Events:\n{}\nExcerpt:\n{:?}",
 550                body.input_events,
 551                body.input_excerpt
 552            );
 553
 554            let input_events = body.input_events.clone();
 555            let input_excerpt = body.input_excerpt.clone();
 556
 557            let response = perform_predict_edits(PerformPredictEditsParams {
 558                client,
 559                llm_token,
 560                app_version,
 561                body,
 562            })
 563            .await;
 564            let (response, usage) = match response {
 565                Ok(response) => response,
 566                Err(err) => {
 567                    if err.is::<ZedUpdateRequiredError>() {
 568                        cx.update(|cx| {
 569                            zeta.update(cx, |zeta, _cx| {
 570                                zeta.update_required = true;
 571                            });
 572
 573                            let error_message: SharedString = err.to_string().into();
 574                            show_app_notification(
 575                                NotificationId::unique::<ZedUpdateRequiredError>(),
 576                                cx,
 577                                move |cx| {
 578                                    cx.new(|cx| {
 579                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 580                                            .with_link_button(
 581                                                "Update Zed",
 582                                                "https://zed.dev/releases",
 583                                            )
 584                                    })
 585                                },
 586                            );
 587                        })
 588                        .ok();
 589                    }
 590
 591                    return Err(err);
 592                }
 593            };
 594
 595            let received_response_at = Instant::now();
 596            log::debug!("completion response: {}", &response.output_excerpt);
 597
 598            if let Some(usage) = usage {
 599                this.update(cx, |this, cx| {
 600                    this.user_store.update(cx, |user_store, cx| {
 601                        user_store.update_edit_prediction_usage(usage, cx);
 602                    });
 603                })
 604                .ok();
 605            }
 606
 607            let request_id = response.request_id.clone();
 608            let edit_prediction = Self::process_completion_response(
 609                response,
 610                buffer,
 611                &snapshot,
 612                editable_range,
 613                cursor_offset,
 614                full_path,
 615                input_events,
 616                input_excerpt,
 617                buffer_snapshotted_at,
 618                cx,
 619            )
 620            .await;
 621
 622            let finished_at = Instant::now();
 623
 624            // record latency for ~1% of requests
 625            if rand::random::<u8>() <= 2 {
 626                telemetry::event!(
 627                    "Edit Prediction Request",
 628                    context_latency = done_gathering_context_at
 629                        .duration_since(buffer_snapshotted_at)
 630                        .as_millis(),
 631                    request_latency = received_response_at
 632                        .duration_since(done_gathering_context_at)
 633                        .as_millis(),
 634                    process_latency = finished_at.duration_since(received_response_at).as_millis()
 635                );
 636            }
 637
 638            if let Some(additional_context_task) = additional_context_task {
 639                cx.background_spawn(async move {
 640                    if let Some(additional_context) = additional_context_task.await {
 641                        telemetry::event!(
 642                            "Edit Prediction Additional Context",
 643                            request_id = request_id,
 644                            additional_context = additional_context
 645                        );
 646                    }
 647                })
 648                .detach();
 649            }
 650
 651            edit_prediction
 652        })
 653    }
 654
    // Generates several example completions of various states to fill the Zeta completion modal
    //
    // Creates a local buffer, requests seven fake completions against it (each with a fixed
    // request id and a canned response), and awaits them in order. Afterwards the edits of
    // the shown completions at indices 2 and 3 are replaced with an empty list.
    //
    // Panics if any fake completion fails, or if fewer than four shown completions exist
    // once they have all finished.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Every fake completion is requested at the start of the second line.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                None,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await the tasks one at a time, in the order they were created.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Clear the edits of two of the shown completions so the modal also has
            // examples without edits.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 806
 807    #[cfg(any(test, feature = "test-support"))]
 808    pub fn fake_completion(
 809        &mut self,
 810        project: Option<Entity<Project>>,
 811        buffer: &Entity<Buffer>,
 812        position: language::Anchor,
 813        response: PredictEditsResponse,
 814        cx: &mut Context<Self>,
 815    ) -> Task<Result<Option<EditPrediction>>> {
 816        use std::future::ready;
 817
 818<<<<<<< HEAD
 819        self.request_completion_impl(
 820            None,
 821            project,
 822            buffer,
 823            position,
 824            CanCollectData(false),
 825            cx,
 826            |_params| ready(Ok((response, None))),
 827        )
 828=======
 829        self.request_completion_impl(project, buffer, position, false, cx, |_params| {
 830            ready(Ok((response, None)))
 831        })
 832>>>>>>> main
 833    }
 834
 835    pub fn request_completion(
 836        &mut self,
 837        project: Option<Entity<Project>>,
 838        buffer: &Entity<Buffer>,
 839        position: language::Anchor,
 840        can_collect_data: CanCollectData,
 841        cx: &mut Context<Self>,
 842    ) -> Task<Result<Option<EditPrediction>>> {
 843        self.request_completion_impl(
 844<<<<<<< HEAD
 845            self.workspace.upgrade(),
 846=======
 847>>>>>>> main
 848            project,
 849            buffer,
 850            position,
 851            can_collect_data,
 852            cx,
 853            Self::perform_predict_edits,
 854        )
 855    }
 856
 857    pub fn perform_predict_edits(
 858        params: PerformPredictEditsParams,
 859    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 860        async move {
 861            let PerformPredictEditsParams {
 862                client,
 863                llm_token,
 864                app_version,
 865                body,
 866                ..
 867            } = params;
 868
 869            let http_client = client.http_client();
 870            let mut token = llm_token.acquire(&client).await?;
 871            let mut did_retry = false;
 872
 873            loop {
 874                let request_builder = http_client::Request::builder().method(Method::POST);
 875                let request_builder =
 876                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 877                        request_builder.uri(predict_edits_url)
 878                    } else {
 879                        request_builder.uri(
 880                            http_client
 881                                .build_zed_llm_url("/predict_edits/v2", &[])?
 882                                .as_ref(),
 883                        )
 884                    };
 885                let request = request_builder
 886                    .header("Content-Type", "application/json")
 887                    .header("Authorization", format!("Bearer {}", token))
 888                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 889                    .body(serde_json::to_string(&body)?.into())?;
 890
 891                let mut response = http_client.send(request).await?;
 892
 893                if let Some(minimum_required_version) = response
 894                    .headers()
 895                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 896                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 897                {
 898                    anyhow::ensure!(
 899                        app_version >= minimum_required_version,
 900                        ZedUpdateRequiredError {
 901                            minimum_version: minimum_required_version
 902                        }
 903                    );
 904                }
 905
 906                if response.status().is_success() {
 907                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 908
 909                    let mut body = String::new();
 910                    response.body_mut().read_to_string(&mut body).await?;
 911                    return Ok((serde_json::from_str(&body)?, usage));
 912                } else if !did_retry
 913                    && response
 914                        .headers()
 915                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 916                        .is_some()
 917                {
 918                    did_retry = true;
 919                    token = llm_token.refresh(&client).await?;
 920                } else {
 921                    let mut body = String::new();
 922                    response.body_mut().read_to_string(&mut body).await?;
 923                    anyhow::bail!(
 924                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 925                        response.status(),
 926                        body
 927                    );
 928                }
 929            }
 930        }
 931    }
 932
 933    fn accept_edit_prediction(
 934        &mut self,
 935        request_id: EditPredictionId,
 936        cx: &mut Context<Self>,
 937    ) -> Task<Result<()>> {
 938        let client = self.client.clone();
 939        let llm_token = self.llm_token.clone();
 940        let app_version = AppVersion::global(cx);
 941        cx.spawn(async move |this, cx| {
 942            let http_client = client.http_client();
 943            let mut response = llm_token_retry(&llm_token, &client, |token| {
 944                let request_builder = http_client::Request::builder().method(Method::POST);
 945                let request_builder =
 946                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 947                        request_builder.uri(accept_prediction_url)
 948                    } else {
 949                        request_builder.uri(
 950                            http_client
 951                                .build_zed_llm_url("/predict_edits/accept", &[])?
 952                                .as_ref(),
 953                        )
 954                    };
 955                Ok(request_builder
 956                    .header("Content-Type", "application/json")
 957                    .header("Authorization", format!("Bearer {}", token))
 958                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 959                    .body(
 960                        serde_json::to_string(&AcceptEditPredictionBody {
 961                            request_id: request_id.0,
 962                        })?
 963                        .into(),
 964                    )?)
 965            })
 966            .await?;
 967
 968            if let Some(minimum_required_version) = response
 969                .headers()
 970                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 971                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 972                && app_version < minimum_required_version
 973            {
 974                return Err(anyhow!(ZedUpdateRequiredError {
 975                    minimum_version: minimum_required_version
 976                }));
 977            }
 978
 979            if response.status().is_success() {
 980                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 981                    this.update(cx, |this, cx| {
 982                        this.user_store.update(cx, |user_store, cx| {
 983                            user_store.update_edit_prediction_usage(usage, cx);
 984                        });
 985                    })?;
 986                }
 987
 988                Ok(())
 989            } else {
 990                let mut body = String::new();
 991                response.body_mut().read_to_string(&mut body).await?;
 992                Err(anyhow!(
 993                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 994                    response.status(),
 995                    body
 996                ))
 997            }
 998        })
 999    }
1000
    /// Turns a raw prediction response into an [`EditPrediction`] for `buffer`.
    ///
    /// The response's output excerpt is parsed into edits on a background task against
    /// `snapshot`, and the edits are then passed through `interpolate` together with the
    /// buffer's current snapshot. The task resolves to `Ok(None)` when `interpolate` yields
    /// nothing, and to an error when parsing fails or the buffer cannot be read.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse the model output into edits off the current task.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-express the edits against the buffer's current snapshot and start building a
            // preview; from here on `snapshot` refers to that current snapshot.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
1060
1061    fn parse_edits(
1062        output_excerpt: Arc<str>,
1063        editable_range: Range<usize>,
1064        snapshot: &BufferSnapshot,
1065    ) -> Result<Vec<(Range<Anchor>, String)>> {
1066        let content = output_excerpt.replace(CURSOR_MARKER, "");
1067
1068        let start_markers = content
1069            .match_indices(EDITABLE_REGION_START_MARKER)
1070            .collect::<Vec<_>>();
1071        anyhow::ensure!(
1072            start_markers.len() == 1,
1073            "expected exactly one start marker, found {}",
1074            start_markers.len()
1075        );
1076
1077        let end_markers = content
1078            .match_indices(EDITABLE_REGION_END_MARKER)
1079            .collect::<Vec<_>>();
1080        anyhow::ensure!(
1081            end_markers.len() == 1,
1082            "expected exactly one end marker, found {}",
1083            end_markers.len()
1084        );
1085
1086        let sof_markers = content
1087            .match_indices(START_OF_FILE_MARKER)
1088            .collect::<Vec<_>>();
1089        anyhow::ensure!(
1090            sof_markers.len() <= 1,
1091            "expected at most one start-of-file marker, found {}",
1092            sof_markers.len()
1093        );
1094
1095        let codefence_start = start_markers[0].0;
1096        let content = &content[codefence_start..];
1097
1098        let newline_ix = content.find('\n').context("could not find newline")?;
1099        let content = &content[newline_ix + 1..];
1100
1101        let codefence_end = content
1102            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
1103            .context("could not find end marker")?;
1104        let new_text = &content[..codefence_end];
1105
1106        let old_text = snapshot
1107            .text_for_range(editable_range.clone())
1108            .collect::<String>();
1109
1110        Ok(Self::compute_edits(
1111            old_text,
1112            new_text,
1113            editable_range.start,
1114            snapshot,
1115        ))
1116    }
1117
1118    pub fn compute_edits(
1119        old_text: String,
1120        new_text: &str,
1121        offset: usize,
1122        snapshot: &BufferSnapshot,
1123    ) -> Vec<(Range<Anchor>, String)> {
1124        text_diff(&old_text, new_text)
1125            .into_iter()
1126            .map(|(mut old_range, new_text)| {
1127                old_range.start += offset;
1128                old_range.end += offset;
1129
1130                let prefix_len = common_prefix(
1131                    snapshot.chars_for_range(old_range.clone()),
1132                    new_text.chars(),
1133                );
1134                old_range.start += prefix_len;
1135
1136                let suffix_len = common_prefix(
1137                    snapshot.reversed_chars_for_range(old_range.clone()),
1138                    new_text[prefix_len..].chars().rev(),
1139                );
1140                old_range.end = old_range.end.saturating_sub(suffix_len);
1141
1142                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1143                let range = if old_range.is_empty() {
1144                    let anchor = snapshot.anchor_after(old_range.start);
1145                    anchor..anchor
1146                } else {
1147                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1148                };
1149                (range, new_text)
1150            })
1151            .collect()
1152    }
1153
    /// Returns `true` if `completion_id` is in the set of rated completions.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
1157
1158    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1159        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1160            let completion = self.shown_completions.pop_back().unwrap();
1161            self.rated_completions.remove(&completion.id);
1162        }
1163        self.shown_completions.push_front(completion.clone());
1164        cx.notify();
1165    }
1166
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// Emits an "Edit Prediction Rated" event carrying the rating, the completion's input
    /// events, input excerpt and output excerpt, and the user's `feedback`, then triggers a
    /// telemetry flush and notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // The flush task is detached rather than awaited.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1186
    /// Iterates over the retained shown completions, front of the queue first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
1190
    /// Returns how many shown completions are currently retained.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1194
1195    fn report_changes_for_buffer(
1196        &mut self,
1197        buffer: &Entity<Buffer>,
1198        cx: &mut Context<Self>,
1199    ) -> BufferSnapshot {
1200        self.register_buffer(buffer, cx);
1201
1202        let registered_buffer = self
1203            .registered_buffers
1204            .get_mut(&buffer.entity_id())
1205            .unwrap();
1206        let new_snapshot = buffer.read(cx).snapshot();
1207
1208        if new_snapshot.version != registered_buffer.snapshot.version {
1209            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1210            self.push_event(Event::BufferChange {
1211                old_snapshot,
1212                new_snapshot: new_snapshot.clone(),
1213                timestamp: Instant::now(),
1214            });
1215        }
1216
1217        new_snapshot
1218    }
1219
1220    fn load_data_collection_choices() -> DataCollectionChoice {
1221        let choice = KEY_VALUE_STORE
1222            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1223            .log_err()
1224            .flatten();
1225
1226        match choice.as_deref() {
1227            Some("true") => DataCollectionChoice::Enabled,
1228            Some("false") => DataCollectionChoice::Disabled,
1229            Some(_) => {
1230                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1231                DataCollectionChoice::NotAnswered
1232            }
1233            None => DataCollectionChoice::NotAnswered,
1234        }
1235    }
1236
1237    fn gather_additional_context(
1238        &mut self,
1239        cursor_point: language::Point,
1240        cursor_offset: usize,
1241        snapshot: BufferSnapshot,
1242        buffer_snapshotted_at: &Instant,
1243        project_path: ProjectPath,
1244        project: Option<&Entity<Project>>,
1245        cx: &mut Context<Self>,
1246    ) -> Option<Task<PredictEditsAdditionalContext>> {
1247        let project = project?.read(cx);
1248        let entry = project.entry_for_path(&project_path, cx)?;
1249        if !worktree_entry_is_eligible_for_collection(&entry) {
1250            return None;
1251        }
1252
1253        let git_store = project.git_store().read(cx);
1254        let (repository, repo_path) =
1255            git_store.repository_and_path_for_project_path(&project_path, cx)?;
1256        let repo_path_string = repo_path.to_str()?.to_string();
1257
1258        let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
1259            snapshot
1260                .diagnostics
1261                .iter()
1262                .filter_map(|(language_server_id, diagnostics)| {
1263                    let language_server =
1264                        local_lsp_store.running_language_server_for_id(*language_server_id)?;
1265                    Some((
1266                        *language_server_id,
1267                        language_server.name(),
1268                        diagnostics.clone(),
1269                    ))
1270                })
1271                .collect()
1272        } else {
1273            Vec::new()
1274        };
1275
1276        repository.update(cx, |repository, cx| {
1277            let head_sha = repository.head_commit.as_ref()?.sha.to_string();
1278            let remote_origin_url = repository.remote_origin_url.clone();
1279            let remote_upstream_url = repository.remote_upstream_url.clone();
1280            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1281
1282            // group, resolve, and select diagnostics on a background thread
1283            Some(cx.background_spawn(async move {
1284                let mut diagnostic_groups_with_name = Vec::new();
1285                for (language_server_id, language_server_name, diagnostics) in
1286                    diagnostics.into_iter()
1287                {
1288                    let mut groups = Vec::new();
1289                    diagnostics.groups(language_server_id, &mut groups, &snapshot);
1290                    diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
1291                        (
1292                            language_server_name.clone(),
1293                            group.resolve::<usize>(&snapshot),
1294                        )
1295                    }));
1296                }
1297
1298                // sort by proximity to cursor
1299                diagnostic_groups_with_name.sort_by_key(|(_, group)| {
1300                    let range = &group.entries[group.primary_ix].range;
1301                    if range.start >= cursor_offset {
1302                        range.start - cursor_offset
1303                    } else if cursor_offset >= range.end {
1304                        cursor_offset - range.end
1305                    } else {
1306                        (cursor_offset - range.start).min(range.end - cursor_offset)
1307                    }
1308                });
1309
1310                let mut diagnostic_groups = Vec::new();
1311                let mut diagnostic_groups_truncated = false;
1312                let mut diagnostics_byte_count = 0;
1313                for (name, group) in diagnostic_groups_with_name {
1314                    let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1315                    diagnostics_byte_count += name.0.len() + raw_value.get().len();
1316                    if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
1317                        diagnostic_groups_truncated = true;
1318                        break;
1319                    }
1320                    diagnostic_groups.push((name.to_string(), raw_value));
1321                }
1322
1323                PredictEditsAdditionalContext {
1324                    input_path: repo_path_string,
1325                    cursor_point: to_cloud_llm_client_point(cursor_point),
1326                    cursor_offset: cursor_offset,
1327                    git_info: PredictEditsGitInfo {
1328                        head_sha: Some(head_sha),
1329                        remote_origin_url,
1330                        remote_upstream_url,
1331                    },
1332                    diagnostic_groups,
1333                    diagnostic_groups_truncated,
1334                    recent_files,
1335                }
1336            }))
1337        })
1338    }
1339
1340    fn handle_active_project_entry_changed(&mut self, cx: &mut Context<Self>) {
1341        if !self.data_collection_choice.read(cx).is_enabled() {
1342            self.recent_editors.clear();
1343            self.last_activity_state = None;
1344            return;
1345        }
1346        if let Some(active_editor) = self
1347            .workspace
1348            .read_with(cx, |workspace, cx| {
1349                workspace
1350                    .active_item(cx)
1351                    .and_then(|item| item.act_as::<Editor>(cx))
1352            })
1353            .ok()
1354            .flatten()
1355        {
1356            let now = Instant::now();
1357            let editor = active_editor.downgrade();
1358            let existing_recent_editor = if let Some(existing_ix) = self
1359                .recent_editors
1360                .iter()
1361                .rposition(|recent| &recent.editor == &editor)
1362            {
1363                if existing_ix + 1 != self.recent_editors.len() {
1364                    self.last_activity_state = None;
1365                }
1366                self.recent_editors.remove(existing_ix)
1367            } else {
1368                None
1369            };
1370            let new_recent = RecentEditor {
1371                editor: active_editor.downgrade(),
1372                last_active_at: now,
1373                activation_count: existing_recent_editor
1374                    .as_ref()
1375                    .map_or(0, |recent| recent.activation_count + 1),
1376                cumulative_time_navigating: existing_recent_editor
1377                    .as_ref()
1378                    .map_or(Duration::ZERO, |recent| recent.cumulative_time_navigating),
1379                cumulative_time_editing: existing_recent_editor
1380                    .map_or(Duration::ZERO, |recent| recent.cumulative_time_editing),
1381            };
1382            // filter out rapid changes in active item, particularly since this can happen rapidly when
1383            // a workspace is loaded.
1384            if let Some(previous_recent) = self.recent_editors.back_mut()
1385                && previous_recent.activation_count == 1
1386                && now.duration_since(previous_recent.last_active_at)
1387                    < MIN_TIME_BETWEEN_RECENT_FILES
1388            {
1389                *previous_recent = new_recent;
1390                return;
1391            }
1392            if self.recent_editors.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1393                self.recent_editors.pop_front();
1394            }
1395            self.recent_editors.push_back(new_recent);
1396        }
1397    }
1398
1399    fn handle_activity_poll(
1400        &mut self,
1401        instant_before_delay: Option<Instant>,
1402        now: Instant,
1403        cx: &mut Context<Self>,
1404    ) {
1405        if !self.data_collection_choice.read(cx).is_enabled() {
1406            self.last_activity_state = None;
1407            return;
1408        }
1409        if let Some(recent_editor) = self.recent_editors.back()
1410            && let Some(editor) = recent_editor.editor.upgrade()
1411        {
1412            let (scroll_position, cursor_point, singleton_version) =
1413                editor.update(cx, |editor, cx| {
1414                    let scroll_position = editor.scroll_position(cx);
1415                    let cursor_point = editor.selections.newest(cx).head();
1416                    let singleton_version = editor
1417                        .buffer()
1418                        .read(cx)
1419                        .as_singleton()
1420                        .map(|singleton_buffer| singleton_buffer.read(cx).version());
1421                    (scroll_position, cursor_point, singleton_version)
1422                });
1423
1424            let navigated = if let Some(last_activity_state) = &self.last_activity_state {
1425                last_activity_state.scroll_position != scroll_position
1426                    || last_activity_state.cursor_point != cursor_point
1427            } else {
1428                false
1429            };
1430
1431            let edited = if let Some(singleton_version) = &singleton_version
1432                && let Some(last_activity_state) = &self.last_activity_state
1433                && let Some(last_singleton_version) = &last_activity_state.singleton_version
1434            {
1435                singleton_version.changed_since(last_singleton_version)
1436            } else {
1437                false
1438            };
1439
1440            self.last_activity_state = Some(ActivityState {
1441                scroll_position,
1442                cursor_point,
1443                singleton_version,
1444            });
1445
1446            let prior_recent_editor = if self.recent_editors.len() > 1 {
1447                Some(&self.recent_editors[self.recent_editors.len() - 2])
1448            } else {
1449                None
1450            };
1451            let additional_time: Option<Duration> =
1452                instant_before_delay.map(|instant_before_delay| {
1453                    now.duration_since(prior_recent_editor.map_or(
1454                        instant_before_delay,
1455                        |prior_recent_editor| {
1456                            prior_recent_editor.last_active_at.max(instant_before_delay)
1457                        },
1458                    ))
1459                });
1460
1461            if let Some(additional_time) = additional_time {
1462                let recent_editor = self.recent_editors.back_mut().unwrap();
1463                if navigated {
1464                    recent_editor.cumulative_time_navigating += additional_time;
1465                }
1466                if edited {
1467                    recent_editor.cumulative_time_editing += additional_time;
1468                }
1469            }
1470        }
1471    }
1472
    /// Builds the list of recently used files to report for `repository`, most recent first,
    /// pruning tracked editors that can no longer be reported.
    ///
    /// Returns an empty list when the workspace cannot be read. A tracked editor is removed
    /// from `recent_editors` when updating it fails or when the per-editor closure yields
    /// `None` (for example: no file, an ineligible file or worktree entry, a path that is not
    /// valid UTF-8 or exceeds `MAX_RECENT_FILE_PATH_LENGTH`, or a duration that does not fit
    /// the target integer type). Editors whose file is outside `repository` are kept but not
    /// reported.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &mut App,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::with_capacity(self.recent_editors.len());
        // Walk indices from the back so that removing entry `ix` does not shift the entries
        // that are still to be visited.
        for ix in (0..self.recent_editors.len()).rev() {
            let recent_editor = &self.recent_editors[ix];
            // `Some(())` means "keep this entry"; `None` means "drop it from the list".
            let keep_entry = recent_editor
                .editor
                .update(cx, |editor, cx| {
                    maybe!({
                        let cursor = editor.selections.newest::<MultiBufferPoint>(cx).head();
                        let multibuffer = editor.buffer().read(cx);
                        let (buffer, cursor_point, _) =
                            multibuffer.point_to_buffer_point(cursor, cx)?;
                        let file = buffer.read(cx).file()?;
                        if !file_is_eligible_for_collection(file.as_ref()) {
                            return None;
                        }
                        let project_path = ProjectPath {
                            worktree_id: file.worktree_id(cx),
                            path: file.path().clone(),
                        };
                        let entry = project.read(cx).entry_for_path(&project_path, cx)?;
                        if !worktree_entry_is_eligible_for_collection(entry) {
                            return None;
                        }
                        let Some(repo_path) =
                            repository.project_path_to_repo_path(&project_path, cx)
                        else {
                            // entry not removed since later queries may involve other repositories
                            return Some(());
                        };
                        // paths may not be valid UTF-8
                        let repo_path_str = repo_path.to_str()?;
                        if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                            return None;
                        }
                        // Millisecond counts are narrowed with `try_into`; a value that does
                        // not fit drops the entry instead of being truncated.
                        let active_to_now_ms = now
                            .duration_since(recent_editor.last_active_at)
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_editing_ms = recent_editor
                            .cumulative_time_editing
                            .as_millis()
                            .try_into()
                            .ok()?;
                        let cumulative_time_navigating_ms = recent_editor
                            .cumulative_time_navigating
                            .as_millis()
                            .try_into()
                            .ok()?;
                        results.push(PredictEditsRecentFile {
                            path: repo_path_str.to_string(),
                            cursor_point: to_cloud_llm_client_point(cursor_point),
                            active_to_now_ms,
                            activation_count: recent_editor.activation_count,
                            cumulative_time_editing_ms,
                            cumulative_time_navigating_ms,
                            is_multibuffer: !multibuffer.is_singleton(),
                        });
                        Some(())
                    })
                })
                .ok()
                .flatten();
            if keep_entry.is_none() {
                self.recent_editors.remove(ix);
            }
        }
        results
    }
1554}
1555
1556fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
1557    cloud_llm_client::Point {
1558        row: point.row,
1559        column: point.column,
1560    }
1561}
1562
1563fn file_is_eligible_for_collection(file: &dyn File) -> bool {
1564    file.is_local() && !file.is_private()
1565}
1566
1567fn worktree_entry_is_eligible_for_collection(entry: &worktree::Entry) -> bool {
1568    entry.is_file()
1569        && entry.is_created()
1570        && !entry.is_ignored
1571        && !entry.is_private
1572        && !entry.is_external
1573        && !entry.is_fifo
1574}
1575
/// Bundle of values handed over together when an edit-prediction request is performed.
///
/// NOTE(review): the consumer of this struct is outside this part of the file; the per-field
/// notes below are inferred from the field names and types and should be confirmed there.
1576pub struct PerformPredictEditsParams {
    /// Client handle (presumably the source of the HTTP client used for the request).
1577    pub client: Arc<Client>,
    /// LLM API token handle (presumably used to authorize the request).
1578    pub llm_token: LlmApiToken,
    /// Version of the running application.
1579    pub app_version: SemanticVersion,
    /// Request payload.
1580    pub body: PredictEditsBody,
1581}
1582
1583#[derive(Error, Debug)]
1584#[error(
1585    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1586)]
/// Error whose display message tells the user to update Zed to at least `minimum_version`
/// before edit predictions can be used again.
1587pub struct ZedUpdateRequiredError {
    /// Version interpolated into the error message as the minimum required one.
1588    minimum_version: SemanticVersion,
1589}
1590
/// Returns the length, in UTF-8 bytes, of the longest run of leading characters that `a` and
/// `b` yield identically.
///
/// Comparison stops at the first differing pair or as soon as either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1597
/// Result of [`gather_context`]: the request body plus the editable range of the excerpt.
1598pub struct GatherContextOutput {
    /// Body to send with the predict-edits request.
1599    pub body: PredictEditsBody,
    /// The excerpt's editable range, converted to offsets in the snapshot it was taken from.
1600    pub editable_range: Range<usize>,
1601}
1602
1603pub async fn gather_context(
1604    full_path_str: String,
1605    snapshot: BufferSnapshot,
1606    cursor_point: language::Point,
1607    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1608    can_collect_data: CanCollectData,
1609) -> Result<GatherContextOutput> {
1610    let input_excerpt = excerpt_for_cursor_position(
1611        cursor_point,
1612        &full_path_str,
1613        &snapshot,
1614        MAX_REWRITE_TOKENS,
1615        MAX_CONTEXT_TOKENS,
1616    );
1617    let input_events = make_events_prompt();
1618    let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1619
1620    let body = PredictEditsBody {
1621        input_events,
1622        input_excerpt: input_excerpt.prompt,
1623        can_collect_data: can_collect_data.0,
1624        diagnostic_groups: None,
1625        git_info: None,
1626    };
1627
1628    Ok(GatherContextOutput {
1629        body,
1630        editable_range,
1631    })
1632}
1633
1634fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1635    let mut result = String::new();
1636    for event in events.iter().rev() {
1637        let event_string = event.to_prompt();
1638        let event_tokens = tokens_for_bytes(event_string.len());
1639        if event_tokens > remaining_tokens {
1640            break;
1641        }
1642
1643        if !result.is_empty() {
1644            result.insert_str(0, "\n\n");
1645        }
1646        result.insert_str(0, &event_string);
1647        remaining_tokens -= event_tokens;
1648    }
1649    result
1650}
1651
/// A buffer snapshot stored together with two subscriptions.
///
/// NOTE(review): the code that creates and reads this struct is outside this part of the file;
/// presumably it tracks a buffer registered for edit predictions — confirm at the use site.
1652struct RegisteredBuffer {
    /// Stored snapshot of the buffer.
1653    snapshot: BufferSnapshot,
    /// Held only so the subscriptions stay alive as long as this value does (never read).
1654    _subscriptions: [gpui::Subscription; 2],
1655}
1656
1657#[derive(Clone)]
/// A recorded event that can be rendered into the prompt via `Event::to_prompt`.
1658pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
1659    BufferChange {
        /// Buffer state before the change.
1660        old_snapshot: BufferSnapshot,
        /// Buffer state after the change.
1661        new_snapshot: BufferSnapshot,
        /// An instant associated with the change (not used when rendering the prompt).
1662        timestamp: Instant,
1663    },
1664}
1665
1666impl Event {
1667    fn to_prompt(&self) -> String {
1668        match self {
1669            Event::BufferChange {
1670                old_snapshot,
1671                new_snapshot,
1672                ..
1673            } => {
1674                let mut prompt = String::new();
1675
1676                let old_path = old_snapshot
1677                    .file()
1678                    .map(|f| f.path().as_ref())
1679                    .unwrap_or(Path::new("untitled"));
1680                let new_path = new_snapshot
1681                    .file()
1682                    .map(|f| f.path().as_ref())
1683                    .unwrap_or(Path::new("untitled"));
1684                if old_path != new_path {
1685                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1686                }
1687
1688                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1689                if !diff.is_empty() {
1690                    write!(
1691                        prompt,
1692                        "User edited {:?}:\n```diff\n{}\n```",
1693                        new_path, diff
1694                    )
1695                    .unwrap();
1696                }
1697
1698                prompt
1699            }
1700        }
1701    }
1702}
1703
1704#[derive(Debug, Clone)]
/// An edit prediction paired with the identity of the buffer it was requested for.
1705struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
1706    buffer_id: EntityId,
    /// The prediction itself.
1707    completion: EditPrediction,
1708}
1709
1710impl CurrentEditPrediction {
1711    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1712        if self.buffer_id != old_completion.buffer_id {
1713            return true;
1714        }
1715
1716        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1717            return true;
1718        };
1719        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1720            return false;
1721        };
1722
1723        if old_edits.len() == 1 && new_edits.len() == 1 {
1724            let (old_range, old_text) = &old_edits[0];
1725            let (new_range, new_text) = &new_edits[0];
1726            new_range == old_range && new_text.starts_with(old_text)
1727        } else {
1728            true
1729        }
1730    }
1731}
1732
/// An in-flight prediction request tracked by `ZetaEditPredictionProvider`.
1733struct PendingCompletion {
    /// Id assigned when the request was started; the task compares it against the first
    /// pending entry when it finishes.
1734    id: usize,
    /// The spawned task; stored here (never read) so it is owned by the pending list.
1735    _task: Task<()>,
1736}
1737
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection is allowed.
    Enabled,
    /// Data collection is not allowed.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` for every variant except [`DataCollectionChoice::NotAnswered`].
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the choice that results from flipping this one: `Enabled` becomes `Disabled`,
    /// while both `Disabled` and `NotAnswered` become `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::NotAnswered | Self::Disabled => Self::Enabled,
            Self::Enabled => Self::Disabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1777
/// Data-collection state attached to one edit-prediction provider.
///
/// Both fields are either populated together or both `None` (see `ProviderDataCollection::new`).
1778pub struct ProviderDataCollection {
1779    /// When set to None, data collection is not possible in the provider buffer
1780    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher consulted to decide whether the project is open source.
1781    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1782}
1783
1784#[derive(Debug, Clone, Copy)]
/// Whether data may be collected for a request. `ProviderDataCollection::can_collect_data`
/// sets it to `true` only when collection is enabled and the project is open source.
1785pub struct CanCollectData(pub bool);
1786
1787impl ProviderDataCollection {
1788    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1789        let choice_and_watcher = buffer.and_then(|buffer| {
1790            let file = buffer.read(cx).file()?;
1791
1792            if !file_is_eligible_for_collection(file.as_ref()) {
1793                return None;
1794            }
1795
1796            let zeta = zeta.read(cx);
1797            let choice = zeta.data_collection_choice.clone();
1798
1799            let license_detection_watcher = zeta
1800                .license_detection_watchers
1801                .get(&file.worktree_id(cx))
1802                .cloned()?;
1803
1804            Some((choice, license_detection_watcher))
1805        });
1806
1807        if let Some((choice, watcher)) = choice_and_watcher {
1808            ProviderDataCollection {
1809                choice: Some(choice),
1810                license_detection_watcher: Some(watcher),
1811            }
1812        } else {
1813            ProviderDataCollection {
1814                choice: None,
1815                license_detection_watcher: None,
1816            }
1817        }
1818    }
1819
1820    pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1821        CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1822    }
1823
1824    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1825        self.choice
1826            .as_ref()
1827            .is_some_and(|choice| choice.read(cx).is_enabled())
1828    }
1829
1830    fn is_project_open_source(&self) -> bool {
1831        self.license_detection_watcher
1832            .as_ref()
1833            .is_some_and(|watcher| watcher.is_project_open_source())
1834    }
1835
1836    pub fn toggle(&mut self, cx: &mut App) {
1837        if let Some(choice) = self.choice.as_mut() {
1838            let new_choice = choice.update(cx, |choice, _cx| {
1839                let new_choice = choice.toggle();
1840                *choice = new_choice;
1841                new_choice
1842            });
1843
1844            db::write_and_log(cx, move || {
1845                KEY_VALUE_STORE.write_kvp(
1846                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1847                    new_choice.is_enabled().to_string(),
1848                )
1849            });
1850        }
1851    }
1852}
1853
1854async fn llm_token_retry(
1855    llm_token: &LlmApiToken,
1856    client: &Arc<Client>,
1857    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1858) -> Result<Response<AsyncBody>> {
1859    let mut did_retry = false;
1860    let http_client = client.http_client();
1861    let mut token = llm_token.acquire(client).await?;
1862    loop {
1863        let request = build_request(token.clone())?;
1864        let response = http_client.send(request).await?;
1865
1866        if !did_retry
1867            && !response.status().is_success()
1868            && response
1869                .headers()
1870                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1871                .is_some()
1872        {
1873            did_retry = true;
1874            token = llm_token.refresh(client).await?;
1875            continue;
1876        }
1877
1878        return Ok(response);
1879    }
1880}
1881
/// Edit-prediction provider backed by a `Zeta` entity.
1882pub struct ZetaEditPredictionProvider {
    /// The `Zeta` entity that requests are delegated to.
1883    zeta: Entity<Zeta>,
    /// In-flight requests; the capacity bounds it to two entries.
1884    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request started by `refresh`; incremented on each use.
1885    next_pending_completion_id: usize,
    /// The prediction currently offered by `suggest`, if any.
1886    current_completion: Option<CurrentEditPrediction>,
1887    /// Data-collection state; its inner choice is `None` when collection is not possible for this provider
1888    provider_data_collection: ProviderDataCollection,
    /// When the most recent request was issued; used to throttle the next one.
1889    last_request_timestamp: Instant,
1890}
1891
1892impl ZetaEditPredictionProvider {
1893    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1894
1895    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1896        Self {
1897            zeta,
1898            pending_completions: ArrayVec::new(),
1899            next_pending_completion_id: 0,
1900            current_completion: None,
1901            provider_data_collection,
1902            last_request_timestamp: Instant::now(),
1903        }
1904    }
1905}
1906
1907impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1908    fn name() -> &'static str {
1909        "zed-predict"
1910    }
1911
1912    fn display_name() -> &'static str {
1913        "Zed's Edit Predictions"
1914    }
1915
1916    fn show_completions_in_menu() -> bool {
1917        true
1918    }
1919
1920    fn show_tab_accept_marker() -> bool {
1921        true
1922    }
1923
1924    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1925        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1926
1927        if self.provider_data_collection.is_data_collection_enabled(cx) {
1928            DataCollectionState::Enabled {
1929                is_project_open_source,
1930            }
1931        } else {
1932            DataCollectionState::Disabled {
1933                is_project_open_source,
1934            }
1935        }
1936    }
1937
1938    fn toggle_data_collection(&mut self, cx: &mut App) {
1939        self.provider_data_collection.toggle(cx);
1940    }
1941
1942    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1943        self.zeta.read(cx).usage(cx)
1944    }
1945
1946    fn is_enabled(
1947        &self,
1948        _buffer: &Entity<Buffer>,
1949        _cursor_position: language::Anchor,
1950        _cx: &App,
1951    ) -> bool {
1952        true
1953    }
1954    fn is_refreshing(&self) -> bool {
1955        !self.pending_completions.is_empty()
1956    }
1957
1958    fn refresh(
1959        &mut self,
1960        project: Option<Entity<Project>>,
1961        buffer: Entity<Buffer>,
1962        position: language::Anchor,
1963        _debounce: bool,
1964        cx: &mut Context<Self>,
1965    ) {
1966        if self.zeta.read(cx).update_required {
1967            return;
1968        }
1969
1970        if self
1971            .zeta
1972            .read(cx)
1973            .user_store
1974            .read_with(cx, |user_store, _cx| {
1975                user_store.account_too_young() || user_store.has_overdue_invoices()
1976            })
1977        {
1978            return;
1979        }
1980
1981        if let Some(current_completion) = self.current_completion.as_ref() {
1982            let snapshot = buffer.read(cx).snapshot();
1983            if current_completion
1984                .completion
1985                .interpolate(&snapshot)
1986                .is_some()
1987            {
1988                return;
1989            }
1990        }
1991
1992        let pending_completion_id = self.next_pending_completion_id;
1993        self.next_pending_completion_id += 1;
1994        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1995        let last_request_timestamp = self.last_request_timestamp;
1996
1997        let task = cx.spawn(async move |this, cx| {
1998            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1999                .checked_duration_since(Instant::now())
2000            {
2001                cx.background_executor().timer(timeout).await;
2002            }
2003
2004            let completion_request = this.update(cx, |this, cx| {
2005                this.last_request_timestamp = Instant::now();
2006                this.zeta.update(cx, |zeta, cx| {
2007                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
2008                })
2009            });
2010
2011            let completion = match completion_request {
2012                Ok(completion_request) => {
2013                    let completion_request = completion_request.await;
2014                    completion_request.map(|c| {
2015                        c.map(|completion| CurrentEditPrediction {
2016                            buffer_id: buffer.entity_id(),
2017                            completion,
2018                        })
2019                    })
2020                }
2021                Err(error) => Err(error),
2022            };
2023            let Some(new_completion) = completion
2024                .context("edit prediction failed")
2025                .log_err()
2026                .flatten()
2027            else {
2028                this.update(cx, |this, cx| {
2029                    if this.pending_completions[0].id == pending_completion_id {
2030                        this.pending_completions.remove(0);
2031                    } else {
2032                        this.pending_completions.clear();
2033                    }
2034
2035                    cx.notify();
2036                })
2037                .ok();
2038                return;
2039            };
2040
2041            this.update(cx, |this, cx| {
2042                if this.pending_completions[0].id == pending_completion_id {
2043                    this.pending_completions.remove(0);
2044                } else {
2045                    this.pending_completions.clear();
2046                }
2047
2048                if let Some(old_completion) = this.current_completion.as_ref() {
2049                    let snapshot = buffer.read(cx).snapshot();
2050                    if new_completion.should_replace_completion(old_completion, &snapshot) {
2051                        this.zeta.update(cx, |zeta, cx| {
2052                            zeta.completion_shown(&new_completion.completion, cx);
2053                        });
2054                        this.current_completion = Some(new_completion);
2055                    }
2056                } else {
2057                    this.zeta.update(cx, |zeta, cx| {
2058                        zeta.completion_shown(&new_completion.completion, cx);
2059                    });
2060                    this.current_completion = Some(new_completion);
2061                }
2062
2063                cx.notify();
2064            })
2065            .ok();
2066        });
2067
2068        // We always maintain at most two pending completions. When we already
2069        // have two, we replace the newest one.
2070        if self.pending_completions.len() <= 1 {
2071            self.pending_completions.push(PendingCompletion {
2072                id: pending_completion_id,
2073                _task: task,
2074            });
2075        } else if self.pending_completions.len() == 2 {
2076            self.pending_completions.pop();
2077            self.pending_completions.push(PendingCompletion {
2078                id: pending_completion_id,
2079                _task: task,
2080            });
2081        }
2082    }
2083
2084    fn cycle(
2085        &mut self,
2086        _buffer: Entity<Buffer>,
2087        _cursor_position: language::Anchor,
2088        _direction: edit_prediction::Direction,
2089        _cx: &mut Context<Self>,
2090    ) {
2091        // Right now we don't support cycling.
2092    }
2093
2094    fn accept(&mut self, cx: &mut Context<Self>) {
2095        let completion_id = self
2096            .current_completion
2097            .as_ref()
2098            .map(|completion| completion.completion.id);
2099        if let Some(completion_id) = completion_id {
2100            self.zeta
2101                .update(cx, |zeta, cx| {
2102                    zeta.accept_edit_prediction(completion_id, cx)
2103                })
2104                .detach();
2105        }
2106        self.pending_completions.clear();
2107    }
2108
2109    fn discard(&mut self, _cx: &mut Context<Self>) {
2110        self.pending_completions.clear();
2111        self.current_completion.take();
2112    }
2113
2114    fn suggest(
2115        &mut self,
2116        buffer: &Entity<Buffer>,
2117        cursor_position: language::Anchor,
2118        cx: &mut Context<Self>,
2119    ) -> Option<edit_prediction::EditPrediction> {
2120        let CurrentEditPrediction {
2121            buffer_id,
2122            completion,
2123            ..
2124        } = self.current_completion.as_mut()?;
2125
2126        // Invalidate previous completion if it was generated for a different buffer.
2127        if *buffer_id != buffer.entity_id() {
2128            self.current_completion.take();
2129            return None;
2130        }
2131
2132        let buffer = buffer.read(cx);
2133        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
2134            self.current_completion.take();
2135            return None;
2136        };
2137
2138        let cursor_row = cursor_position.to_point(buffer).row;
2139        let (closest_edit_ix, (closest_edit_range, _)) =
2140            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
2141                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
2142                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
2143                cmp::min(distance_from_start, distance_from_end)
2144            })?;
2145
2146        let mut edit_start_ix = closest_edit_ix;
2147        for (range, _) in edits[..edit_start_ix].iter().rev() {
2148            let distance_from_closest_edit =
2149                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
2150            if distance_from_closest_edit <= 1 {
2151                edit_start_ix -= 1;
2152            } else {
2153                break;
2154            }
2155        }
2156
2157        let mut edit_end_ix = closest_edit_ix + 1;
2158        for (range, _) in &edits[edit_end_ix..] {
2159            let distance_from_closest_edit =
2160                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
2161            if distance_from_closest_edit <= 1 {
2162                edit_end_ix += 1;
2163            } else {
2164                break;
2165            }
2166        }
2167
2168        Some(edit_prediction::EditPrediction {
2169            id: Some(completion.id.to_string().into()),
2170            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
2171            edit_preview: Some(completion.edit_preview.clone()),
2172        })
2173    }
2174}
2175
/// Estimates how many tokens a string of `bytes` bytes occupies, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This
    /// is intentionally low to err on the side of underestimating limits.
    const ASSUMED_BYTES_PER_TOKEN: usize = 3;

    let estimated_tokens = bytes / ASSUMED_BYTES_PER_TOKEN;
    estimated_tokens
}
2182
2183#[cfg(test)]
2184mod tests {
2185    use client::UserStore;
2186    use client::test::FakeServer;
2187    use clock::FakeSystemClock;
2188    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
2189    use gpui::TestAppContext;
2190    use http_client::FakeHttpClient;
2191    use indoc::indoc;
2192    use language::Point;
2193    use settings::SettingsStore;
2194
2195    use super::*;
2196
2197    #[gpui::test]
    /// Checks how `EditPrediction::interpolate` adjusts a prediction's edits as the buffer is
    /// edited after the prediction was made.
2198    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
2199        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // The prediction: replace offsets 2..5 with "REM" and delete offsets 9..11.
2200        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
2201            to_completion_edits(
2202                [(2..5, "REM".to_string()), (9..11, "".to_string())],
2203                &buffer,
2204                cx,
2205            )
2206            .into()
2207        });
2208
2209        let edit_preview = cx
2210            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
2211            .await;
2212
2213        let completion = EditPrediction {
2214            edits,
2215            edit_preview,
2216            path: Path::new("").into(),
2217            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
2218            id: EditPredictionId(Uuid::new_v4()),
2219            excerpt_range: 0..0,
2220            cursor_offset: 0,
2221            input_events: "".into(),
2222            input_excerpt: "".into(),
2223            output_excerpt: "".into(),
2224            buffer_snapshotted_at: Instant::now(),
2225            response_received_at: Instant::now(),
2226        };
2227
2228        cx.update(|cx| {
            // Unchanged buffer: the edits come back exactly as predicted.
2229            assert_eq!(
2230                from_completion_edits(
2231                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2232                    &buffer,
2233                    cx
2234                ),
2235                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2236            );
2237
            // Deleting the text the first edit would replace turns that edit into a pure
            // insertion and shifts the second edit left by three.
2238            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2239            assert_eq!(
2240                from_completion_edits(
2241                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2242                    &buffer,
2243                    cx
2244                ),
2245                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
2246            );
2247
            // Undoing restores the original edits.
2248            buffer.update(cx, |buffer, cx| buffer.undo(cx));
2249            assert_eq!(
2250                from_completion_edits(
2251                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2252                    &buffer,
2253                    cx
2254                ),
2255                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
2256            );
2257
            // Typing the first predicted character leaves only the rest ("EM") to insert.
2258            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2259            assert_eq!(
2260                from_completion_edits(
2261                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2262                    &buffer,
2263                    cx
2264                ),
2265                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
2266            );
2267
            // Typing the second predicted character leaves only "M".
2268            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2269            assert_eq!(
2270                from_completion_edits(
2271                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2272                    &buffer,
2273                    cx
2274                ),
2275                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2276            );
2277
            // Once the whole predicted text is typed, only the deletion remains.
2278            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2279            assert_eq!(
2280                from_completion_edits(
2281                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2282                    &buffer,
2283                    cx
2284                ),
2285                vec![(9..11, "".to_string())]
2286            );
2287
            // Removing the typed "M" again brings the "M" insertion back.
2288            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2289            assert_eq!(
2290                from_completion_edits(
2291                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2292                    &buffer,
2293                    cx
2294                ),
2295                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
2296            );
2297
            // Performing the predicted deletion by hand removes it from the remaining edits.
2298            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2299            assert_eq!(
2300                from_completion_edits(
2301                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2302                    &buffer,
2303                    cx
2304                ),
2305                vec![(4..4, "M".to_string())]
2306            );
2307
            // After this further deletion the prediction can no longer be interpolated.
2308            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2309            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
2310        })
2311    }
2312
2313    #[gpui::test]
    /// Checks that a model response is reduced to small insertions at the changed positions
    /// (see the expected edits asserted below).
2314    async fn test_clean_up_diff(cx: &mut TestAppContext) {
2315        cx.update(|cx| {
2316            let settings_store = SettingsStore::test(cx);
2317            cx.set_global(settings_store);
2318            client::init_settings(cx);
2319        });
2320
        // Renaming `word` to `word_1` twice on one line yields two "_1" insertions.
2321        let edits = edits_for_prediction(
2322            indoc! {"
2323                fn main() {
2324                    let word_1 = \"lorem\";
2325                    let range = word.len()..word.len();
2326                }
2327            "},
2328            indoc! {"
2329                <|editable_region_start|>
2330                fn main() {
2331                    let word_1 = \"lorem\";
2332                    let range = word_1.len()..word_1.len();
2333                }
2334
2335                <|editable_region_end|>
2336            "},
2337            cx,
2338        )
2339        .await;
2340        assert_eq!(
2341            edits,
2342            [
2343                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2344                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2345            ]
2346        );
2347
        // Extending a string literal and appending `;` yields two insertions around the
        // closing quote.
2348        let edits = edits_for_prediction(
2349            indoc! {"
2350                fn main() {
2351                    let story = \"the quick\"
2352                }
2353            "},
2354            indoc! {"
2355                <|editable_region_start|>
2356                fn main() {
2357                    let story = \"the quick brown fox jumps over the lazy dog\";
2358                }
2359
2360                <|editable_region_end|>
2361            "},
2362            cx,
2363        )
2364        .await;
2365        assert_eq!(
2366            edits,
2367            [
2368                (
2369                    Point::new(1, 26)..Point::new(1, 26),
2370                    " brown fox jumps over the lazy dog".to_string()
2371                ),
2372                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2373            ]
2374        );
2375    }
2376
2377    #[gpui::test]
    /// Requests a prediction with the cursor on the empty last line of the buffer and checks
    /// the buffer text after applying the returned edits.
2378    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2379        cx.update(|cx| {
2380            let settings_store = SettingsStore::test(cx);
2381            cx.set_global(settings_store);
2382            client::init_settings(cx);
2383        });
2384
2385        let buffer_content = "lorem\n";
2386        let completion_response = indoc! {"
2387            ```animals.js
2388            <|start_of_file|>
2389            <|editable_region_start|>
2390            lorem
2391            ipsum
2392            <|editable_region_end|>
2393            ```"};
2394
        // Fake HTTP backend: answers the LLM-token and predict-edits endpoints, 404 otherwise.
2395        let http_client = FakeHttpClient::create(move |req| async move {
2396            match (req.method(), req.uri().path()) {
2397                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2398                    .status(200)
2399                    .body(
2400                        serde_json::to_string(&CreateLlmTokenResponse {
2401                            token: LlmToken("the-llm-token".to_string()),
2402                        })
2403                        .unwrap()
2404                        .into(),
2405                    )
2406                    .unwrap()),
2407                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2408                    .status(200)
2409                    .body(
2410                        serde_json::to_string(&PredictEditsResponse {
2411                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2412                                .unwrap(),
2413                            output_excerpt: completion_response.to_string(),
2414                        })
2415                        .unwrap()
2416                        .into(),
2417                    )
2418                    .unwrap()),
2419                _ => Ok(http_client::Response::builder()
2420                    .status(404)
2421                    .body("Not Found".into())
2422                    .unwrap()),
2423            }
2424        });
2425
2426        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2427        cx.update(|cx| {
2428            RefreshLlmTokenListener::register(client.clone(), cx);
2429        });
2430        // Construct the fake server to authenticate.
2431        let _server = FakeServer::for_client(42, &client, cx).await;
2432        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2433        let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2434
2435        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Row 1, column 0 is the empty line after the trailing newline of "lorem\n".
2436        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2437        let completion_task = zeta.update(cx, |zeta, cx| {
2438            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2439        });
2440
2441        let completion = completion_task.await.unwrap().unwrap();
        // Apply the predicted edits and compare the resulting text.
2442        buffer.update(cx, |buffer, cx| {
2443            buffer.edit(completion.edits.iter().cloned(), None, cx)
2444        });
2445        assert_eq!(
2446            buffer.read_with(cx, |buffer, _| buffer.text()),
2447            "lorem\nipsum"
2448        );
2449    }
2450
2451    async fn edits_for_prediction(
2452        buffer_content: &str,
2453        completion_response: &str,
2454        cx: &mut TestAppContext,
2455    ) -> Vec<(Range<Point>, String)> {
2456        let completion_response = completion_response.to_string();
2457        let http_client = FakeHttpClient::create(move |req| {
2458            let completion = completion_response.clone();
2459            async move {
2460                match (req.method(), req.uri().path()) {
2461                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2462                        .status(200)
2463                        .body(
2464                            serde_json::to_string(&CreateLlmTokenResponse {
2465                                token: LlmToken("the-llm-token".to_string()),
2466                            })
2467                            .unwrap()
2468                            .into(),
2469                        )
2470                        .unwrap()),
2471                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2472                        .status(200)
2473                        .body(
2474                            serde_json::to_string(&PredictEditsResponse {
2475                                request_id: Uuid::new_v4(),
2476                                output_excerpt: completion,
2477                            })
2478                            .unwrap()
2479                            .into(),
2480                        )
2481                        .unwrap()),
2482                    _ => Ok(http_client::Response::builder()
2483                        .status(404)
2484                        .body("Not Found".into())
2485                        .unwrap()),
2486                }
2487            }
2488        });
2489
2490        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2491        cx.update(|cx| {
2492            RefreshLlmTokenListener::register(client.clone(), cx);
2493        });
2494        // Construct the fake server to authenticate.
2495        let _server = FakeServer::for_client(42, &client, cx).await;
2496        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2497        let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
2498
2499        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2500        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2501        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2502        let completion_task = zeta.update(cx, |zeta, cx| {
2503            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
2504        });
2505
2506        let completion = completion_task.await.unwrap().unwrap();
2507        completion
2508            .edits
2509            .iter()
2510            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2511            .collect::<Vec<_>>()
2512    }
2513
2514    fn to_completion_edits(
2515        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2516        buffer: &Entity<Buffer>,
2517        cx: &App,
2518    ) -> Vec<(Range<Anchor>, String)> {
2519        let buffer = buffer.read(cx);
2520        iterator
2521            .into_iter()
2522            .map(|(range, text)| {
2523                (
2524                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2525                    text,
2526                )
2527            })
2528            .collect()
2529    }
2530
2531    fn from_completion_edits(
2532        editor_edits: &[(Range<Anchor>, String)],
2533        buffer: &Entity<Buffer>,
2534        cx: &App,
2535    ) -> Vec<(Range<usize>, String)> {
2536        let buffer = buffer.read(cx);
2537        editor_edits
2538            .iter()
2539            .map(|(range, text)| {
2540                (
2541                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
2542                    text.clone(),
2543                )
2544            })
2545            .collect()
2546    }
2547
2548    #[ctor::ctor]
2549    fn init_logger() {
2550        zlog::init_test();
2551    }
2552}