zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9use arrayvec::ArrayVec;
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction::DataCollectionState;
  13pub use init::*;
  14use license_detection::LicenseDetectionWatcher;
  15use project::git_store::Repository;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use client::{Client, EditPredictionUsage, UserStore};
  20use cloud_llm_client::{
  21    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  22    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
  23    ZED_VERSION_HEADER_NAME,
  24};
  25use collections::{HashMap, HashSet, VecDeque};
  26use futures::AsyncReadExt;
  27use gpui::{
  28    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  29    Subscription, Task, WeakEntity, actions,
  30};
  31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  32use input_excerpt::excerpt_for_cursor_position;
  33use language::{
  34    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  35};
  36use language_model::{LlmApiToken, RefreshLlmTokenListener};
  37use project::{Project, ProjectEntryId, ProjectPath};
  38use release_channel::AppVersion;
  39use settings::WorktreeId;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    mem,
  46    ops::Range,
  47    path::Path,
  48    rc::Rc,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::EditPredictionRating;
  53use thiserror::Error;
  54use util::ResultExt;
  55use uuid::Uuid;
  56use workspace::Workspace;
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  58use worktree::Worktree;
  59
// Marker strings embedded in prompt/response text. The editable-region markers
// are what the fake responses in `fill_with_fake_completions` wrap their output in.
// NOTE(review): the cursor and start-of-file markers are presumably consumed by
// the excerpt/prompt builders elsewhere in this crate — confirm.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
// Consecutive buffer-change events no further apart than this are coalesced
// into a single event by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
// Key-value store key; `ZedPredictUpsell::dismissed` treats its presence as
// evidence that the earlier onboarding flow was already completed.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Size budgets. `MAX_EVENT_TOKENS` is passed to `prompt_for_events`; the other
// limits are presumably applied by the context-gathering code — TODO confirm.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
  80
// Declares the actions exposed under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  88
/// Identifier of a single edit prediction, stored as a [`Uuid`].
///
/// The wrapped UUID is what gets sent back as `request_id` when a prediction
/// is accepted.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
  91
  92impl From<EditPredictionId> for gpui::ElementId {
  93    fn from(value: EditPredictionId) -> Self {
  94        gpui::ElementId::Uuid(value.0)
  95    }
  96}
  97
  98impl std::fmt::Display for EditPredictionId {
  99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 100        write!(f, "{}", self.0)
 101    }
 102}
 103
/// Marker type carrying the dismissal state of the edit prediction upsell
/// (see its `Dismissable` implementation).
struct ZedPredictUpsell;
 105
 106impl Dismissable for ZedPredictUpsell {
 107    const KEY: &'static str = "dismissed-edit-predict-upsell";
 108
 109    fn dismissed() -> bool {
 110        // To make this backwards compatible with older versions of Zed, we
 111        // check if the user has seen the previous Edit Prediction Onboarding
 112        // before, by checking the data collection choice which was written to
 113        // the database once the user clicked on "Accept and Enable"
 114        if KEY_VALUE_STORE
 115            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 116            .log_err()
 117            .is_some_and(|s| s.is_some())
 118        {
 119            return true;
 120        }
 121
 122        KEY_VALUE_STORE
 123            .read_kvp(Self::KEY)
 124            .log_err()
 125            .is_some_and(|s| s.is_some())
 126    }
 127}
 128
 129pub fn should_show_upsell_modal() -> bool {
 130    !ZedPredictUpsell::dismissed()
 131}
 132
/// App-global wrapper around the shared [`Zeta`] entity; installed by
/// `Zeta::register` and looked up by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 137
/// A single edit prediction together with the inputs it was produced from.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    path: Arc<Path>,
    excerpt_range: Range<usize>,
    cursor_offset: usize,
    // Predicted edits: anchor ranges (resolved against `snapshot`) paired with
    // their replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    // Buffer snapshot the `edits` anchors are interpreted against when
    // interpolating onto a newer snapshot.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    // The two instants whose difference `latency()` reports.
    buffer_snapshotted_at: Instant,
    response_received_at: Instant,
}
 154
 155impl EditPrediction {
 156    fn latency(&self) -> Duration {
 157        self.response_received_at
 158            .duration_since(self.buffer_snapshotted_at)
 159    }
 160
 161    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 162        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 163    }
 164}
 165
 166fn interpolate(
 167    old_snapshot: &BufferSnapshot,
 168    new_snapshot: &BufferSnapshot,
 169    current_edits: Arc<[(Range<Anchor>, String)]>,
 170) -> Option<Vec<(Range<Anchor>, String)>> {
 171    let mut edits = Vec::new();
 172
 173    let mut model_edits = current_edits.iter().peekable();
 174    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 175        while let Some((model_old_range, _)) = model_edits.peek() {
 176            let model_old_range = model_old_range.to_offset(old_snapshot);
 177            if model_old_range.end < user_edit.old.start {
 178                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 179                edits.push((model_old_range.clone(), model_new_text.clone()));
 180            } else {
 181                break;
 182            }
 183        }
 184
 185        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 186            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 187            if user_edit.old == model_old_offset_range {
 188                let user_new_text = new_snapshot
 189                    .text_for_range(user_edit.new.clone())
 190                    .collect::<String>();
 191
 192                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 193                    if !model_suffix.is_empty() {
 194                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 195                        edits.push((anchor..anchor, model_suffix.to_string()));
 196                    }
 197
 198                    model_edits.next();
 199                    continue;
 200                }
 201            }
 202        }
 203
 204        return None;
 205    }
 206
 207    edits.extend(model_edits.cloned());
 208
 209    if edits.is_empty() { None } else { Some(edits) }
 210}
 211
 212impl std::fmt::Debug for EditPrediction {
 213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 214        f.debug_struct("EditPrediction")
 215            .field("id", &self.id)
 216            .field("path", &self.path)
 217            .field("edits", &self.edits)
 218            .finish_non_exhaustive()
 219    }
 220}
 221
/// State behind edit predictions: the HTTP client and LLM token used for
/// requests, the recent buffer-change history, and bookkeeping for shown and
/// rated predictions.
pub struct Zeta {
    // Workspace used to surface the "update required" notification; may be an
    // invalid weak handle when no workspace was supplied at construction.
    workspace: WeakEntity<Workspace>,
    client: Arc<Client>,
    // Recent buffer-change events, coalesced and capped by `push_event`.
    events: VecDeque<Event>,
    // Buffers registered through `register_buffer`, keyed by entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    shown_completions: VecDeque<EditPrediction>,
    rated_completions: HashSet<EditPredictionId>,
    data_collection_choice: Entity<DataCollectionChoice>,
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription alive for as long as `Zeta` lives.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    user_store: Entity<UserStore>,
    // One watcher per worktree, created on first registration of that worktree.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
 238
 239impl Zeta {
 240    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 241        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 242    }
 243
 244    pub fn register(
 245        workspace: Option<Entity<Workspace>>,
 246        worktree: Option<Entity<Worktree>>,
 247        client: Arc<Client>,
 248        user_store: Entity<UserStore>,
 249        cx: &mut App,
 250    ) -> Entity<Self> {
 251        let this = Self::global(cx).unwrap_or_else(|| {
 252            let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
 253            cx.set_global(ZetaGlobal(entity.clone()));
 254            entity
 255        });
 256
 257        this.update(cx, move |this, cx| {
 258            if let Some(worktree) = worktree {
 259                let worktree_id = worktree.read(cx).id();
 260                this.license_detection_watchers
 261                    .entry(worktree_id)
 262                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 263            }
 264        });
 265
 266        this
 267    }
 268
 269    pub fn clear_history(&mut self) {
 270        self.events.clear();
 271    }
 272
 273    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 274        self.user_store.read(cx).edit_prediction_usage()
 275    }
 276
 277    fn new(
 278        workspace: Option<Entity<Workspace>>,
 279        client: Arc<Client>,
 280        user_store: Entity<UserStore>,
 281        cx: &mut Context<Self>,
 282    ) -> Self {
 283        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 284
 285        let data_collection_choice = Self::load_data_collection_choices();
 286        let data_collection_choice = cx.new(|_| data_collection_choice);
 287
 288        if let Some(workspace) = &workspace {
 289            cx.subscribe(
 290                &workspace.read(cx).project().clone(),
 291                |this, _workspace, event, _cx| match event {
 292                    project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
 293                        this.push_recent_project_entry(*project_entry_id)
 294                    }
 295                    _ => {}
 296                },
 297            )
 298            .detach();
 299        }
 300
 301        Self {
 302            workspace: workspace.map_or_else(
 303                || WeakEntity::new_invalid(),
 304                |workspace| workspace.downgrade(),
 305            ),
 306            client,
 307            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 308            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
 309            rated_completions: HashSet::default(),
 310            registered_buffers: HashMap::default(),
 311            data_collection_choice,
 312            llm_token: LlmApiToken::default(),
 313            _llm_token_subscription: cx.subscribe(
 314                &refresh_llm_token_listener,
 315                |this, _listener, _event, cx| {
 316                    let client = this.client.clone();
 317                    let llm_token = this.llm_token.clone();
 318                    cx.spawn(async move |_this, _cx| {
 319                        llm_token.refresh(&client).await?;
 320                        anyhow::Ok(())
 321                    })
 322                    .detach_and_log_err(cx);
 323                },
 324            ),
 325            update_required: false,
 326            license_detection_watchers: HashMap::default(),
 327            user_store,
 328            recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
 329        }
 330    }
 331
 332    fn push_event(&mut self, event: Event) {
 333        if let Some(Event::BufferChange {
 334            new_snapshot: last_new_snapshot,
 335            timestamp: last_timestamp,
 336            ..
 337        }) = self.events.back_mut()
 338        {
 339            // Coalesce edits for the same buffer when they happen one after the other.
 340            let Event::BufferChange {
 341                old_snapshot,
 342                new_snapshot,
 343                timestamp,
 344            } = &event;
 345
 346            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 347                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 348                && old_snapshot.version == last_new_snapshot.version
 349            {
 350                *last_new_snapshot = new_snapshot.clone();
 351                *last_timestamp = *timestamp;
 352                return;
 353            }
 354        }
 355
 356        if self.events.len() >= MAX_EVENT_COUNT {
 357            // These are halved instead of popping to improve prompt caching.
 358            self.events.drain(..MAX_EVENT_COUNT / 2);
 359        }
 360
 361        self.events.push_back(event);
 362    }
 363
 364    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 365        let buffer_id = buffer.entity_id();
 366        let weak_buffer = buffer.downgrade();
 367
 368        if let std::collections::hash_map::Entry::Vacant(entry) =
 369            self.registered_buffers.entry(buffer_id)
 370        {
 371            let snapshot = buffer.read(cx).snapshot();
 372
 373            entry.insert(RegisteredBuffer {
 374                snapshot,
 375                _subscriptions: [
 376                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 377                        this.handle_buffer_event(buffer, event, cx);
 378                    }),
 379                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 380                        this.registered_buffers.remove(&weak_buffer.entity_id());
 381                    }),
 382                ],
 383            });
 384        };
 385    }
 386
 387    fn handle_buffer_event(
 388        &mut self,
 389        buffer: Entity<Buffer>,
 390        event: &language::BufferEvent,
 391        cx: &mut Context<Self>,
 392    ) {
 393        if let language::BufferEvent::Edited = event {
 394            self.report_changes_for_buffer(&buffer, cx);
 395        }
 396    }
 397
    /// Shared implementation behind `request_completion` and `fake_completion`.
    ///
    /// Synchronously: reports pending changes for `buffer` (which yields the
    /// snapshot used for the request), optionally collects git info and recent
    /// files, and starts gathering the request context. Asynchronously: awaits
    /// the gathered context, calls `perform_predict_edits` with it, records
    /// any reported usage on the user store, and hands the response to
    /// `process_completion_response`, whose result is returned.
    ///
    /// * `workspace` - only used to show a notification when the request fails
    ///   with `ZedUpdateRequiredError`.
    /// * `project` - forwarded to context gathering; also needed (together
    ///   with `can_collect_data` and a file-backed buffer) to look up git info.
    /// * `cursor` - cursor position as an anchor, resolved against the
    ///   snapshot taken here.
    /// * `can_collect_data` - gates collection of git info and recent files
    ///   and is forwarded to `gather_context`.
    /// * `perform_predict_edits` - performs the actual prediction request;
    ///   injectable so tests can supply canned responses.
    ///
    /// Errors from context gathering and from `perform_predict_edits` are
    /// returned to the caller.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency measurements taken further below.
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        // Clone the event history so the prompt is built from the state at
        // request time.
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Git info and recent files are only gathered when data collection is
        // allowed, a project is available, the buffer is backed by a file, and
        // that file belongs to a git repository.
        let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
            (can_collect_data, project, snapshot.file())
            && let Some(repository) =
                git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        {
            let repository = repository.read(cx);
            let git_info = make_predict_edits_git_info(repository);
            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
            (git_info, Some(recent_files))
        } else {
            (None, None)
        };

        // Buffers without a file are reported under the path "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Passed as a closure so `gather_context` decides when to build the
        // events prompt.
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            recent_files,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the inputs: `body` is moved into the request
            // below, but the inputs are also passed on to response processing.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // An outdated client is flagged on the entity and, when a
                    // workspace is available, surfaced as a notification with
                    // an update link. The error is returned either way.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage is optional; a failure to update the entity is ignored.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests
            // (the check passes for 3 of the 256 possible `u8` values)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 560
    // Generates several example completions of various states to fill the Zeta completion modal
    //
    // Creates a local buffer with fixed text, requests seven fake completions
    // for it (each with a canned response), awaits them in order, and finally
    // replaces the edits of the completions at indices 2 and 3 of
    // `shown_completions` with an empty list.
    //
    // Panics if any fake completion fails or if `shown_completions` has fewer
    // than four entries afterwards.
    // NOTE(review): this relies on each requested completion being recorded in
    // `shown_completions` by code outside this function — confirm.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Every fake request uses the start of the second line as the cursor.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await sequentially, in the order the requests were created.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Blank out the edits of two of the stored completions so the
            // modal also has entries without edits to display.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 713
 714    #[cfg(any(test, feature = "test-support"))]
 715    pub fn fake_completion(
 716        &mut self,
 717        project: Option<&Entity<Project>>,
 718        buffer: &Entity<Buffer>,
 719        position: language::Anchor,
 720        response: PredictEditsResponse,
 721        cx: &mut Context<Self>,
 722    ) -> Task<Result<Option<EditPrediction>>> {
 723        use std::future::ready;
 724
 725        self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
 726            ready(Ok((response, None)))
 727        })
 728    }
 729
 730    pub fn request_completion(
 731        &mut self,
 732        project: Option<&Entity<Project>>,
 733        buffer: &Entity<Buffer>,
 734        position: language::Anchor,
 735        can_collect_data: bool,
 736        cx: &mut Context<Self>,
 737    ) -> Task<Result<Option<EditPrediction>>> {
 738        self.request_completion_impl(
 739            self.workspace.upgrade(),
 740            project,
 741            buffer,
 742            position,
 743            can_collect_data,
 744            cx,
 745            Self::perform_predict_edits,
 746        )
 747    }
 748
 749    pub fn perform_predict_edits(
 750        params: PerformPredictEditsParams,
 751    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 752        async move {
 753            let PerformPredictEditsParams {
 754                client,
 755                llm_token,
 756                app_version,
 757                body,
 758                ..
 759            } = params;
 760
 761            let http_client = client.http_client();
 762            let mut token = llm_token.acquire(&client).await?;
 763            let mut did_retry = false;
 764
 765            loop {
 766                let request_builder = http_client::Request::builder().method(Method::POST);
 767                let request_builder =
 768                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 769                        request_builder.uri(predict_edits_url)
 770                    } else {
 771                        request_builder.uri(
 772                            http_client
 773                                .build_zed_llm_url("/predict_edits/v2", &[])?
 774                                .as_ref(),
 775                        )
 776                    };
 777                let request = request_builder
 778                    .header("Content-Type", "application/json")
 779                    .header("Authorization", format!("Bearer {}", token))
 780                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 781                    .body(serde_json::to_string(&body)?.into())?;
 782
 783                let mut response = http_client.send(request).await?;
 784
 785                if let Some(minimum_required_version) = response
 786                    .headers()
 787                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 788                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 789                {
 790                    anyhow::ensure!(
 791                        app_version >= minimum_required_version,
 792                        ZedUpdateRequiredError {
 793                            minimum_version: minimum_required_version
 794                        }
 795                    );
 796                }
 797
 798                if response.status().is_success() {
 799                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 800
 801                    let mut body = String::new();
 802                    response.body_mut().read_to_string(&mut body).await?;
 803                    return Ok((serde_json::from_str(&body)?, usage));
 804                } else if !did_retry
 805                    && response
 806                        .headers()
 807                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 808                        .is_some()
 809                {
 810                    did_retry = true;
 811                    token = llm_token.refresh(&client).await?;
 812                } else {
 813                    let mut body = String::new();
 814                    response.body_mut().read_to_string(&mut body).await?;
 815                    anyhow::bail!(
 816                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 817                        response.status(),
 818                        body
 819                    );
 820                }
 821            }
 822        }
 823    }
 824
 825    fn accept_edit_prediction(
 826        &mut self,
 827        request_id: EditPredictionId,
 828        cx: &mut Context<Self>,
 829    ) -> Task<Result<()>> {
 830        let client = self.client.clone();
 831        let llm_token = self.llm_token.clone();
 832        let app_version = AppVersion::global(cx);
 833        cx.spawn(async move |this, cx| {
 834            let http_client = client.http_client();
 835            let mut response = llm_token_retry(&llm_token, &client, |token| {
 836                let request_builder = http_client::Request::builder().method(Method::POST);
 837                let request_builder =
 838                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 839                        request_builder.uri(accept_prediction_url)
 840                    } else {
 841                        request_builder.uri(
 842                            http_client
 843                                .build_zed_llm_url("/predict_edits/accept", &[])?
 844                                .as_ref(),
 845                        )
 846                    };
 847                Ok(request_builder
 848                    .header("Content-Type", "application/json")
 849                    .header("Authorization", format!("Bearer {}", token))
 850                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 851                    .body(
 852                        serde_json::to_string(&AcceptEditPredictionBody {
 853                            request_id: request_id.0,
 854                        })?
 855                        .into(),
 856                    )?)
 857            })
 858            .await?;
 859
 860            if let Some(minimum_required_version) = response
 861                .headers()
 862                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 863                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 864                && app_version < minimum_required_version
 865            {
 866                return Err(anyhow!(ZedUpdateRequiredError {
 867                    minimum_version: minimum_required_version
 868                }));
 869            }
 870
 871            if response.status().is_success() {
 872                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 873                    this.update(cx, |this, cx| {
 874                        this.user_store.update(cx, |user_store, cx| {
 875                            user_store.update_edit_prediction_usage(usage, cx);
 876                        });
 877                    })?;
 878                }
 879
 880                Ok(())
 881            } else {
 882                let mut body = String::new();
 883                response.body_mut().read_to_string(&mut body).await?;
 884                Err(anyhow!(
 885                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 886                    response.status(),
 887                    body
 888                ))
 889            }
 890        })
 891    }
 892
 893    fn process_completion_response(
 894        prediction_response: PredictEditsResponse,
 895        buffer: Entity<Buffer>,
 896        snapshot: &BufferSnapshot,
 897        editable_range: Range<usize>,
 898        cursor_offset: usize,
 899        path: Arc<Path>,
 900        input_outline: String,
 901        input_events: String,
 902        input_excerpt: String,
 903        buffer_snapshotted_at: Instant,
 904        cx: &AsyncApp,
 905    ) -> Task<Result<Option<EditPrediction>>> {
 906        let snapshot = snapshot.clone();
 907        let request_id = prediction_response.request_id;
 908        let output_excerpt = prediction_response.output_excerpt;
 909        cx.spawn(async move |cx| {
 910            let output_excerpt: Arc<str> = output_excerpt.into();
 911
 912            let edits: Arc<[(Range<Anchor>, String)]> = cx
 913                .background_spawn({
 914                    let output_excerpt = output_excerpt.clone();
 915                    let editable_range = editable_range.clone();
 916                    let snapshot = snapshot.clone();
 917                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
 918                })
 919                .await?
 920                .into();
 921
 922            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
 923                let edits = edits.clone();
 924                |buffer, cx| {
 925                    let new_snapshot = buffer.snapshot();
 926                    let edits: Arc<[(Range<Anchor>, String)]> =
 927                        interpolate(&snapshot, &new_snapshot, edits)?.into();
 928                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 929                }
 930            })?
 931            else {
 932                return anyhow::Ok(None);
 933            };
 934
 935            let edit_preview = edit_preview.await;
 936
 937            Ok(Some(EditPrediction {
 938                id: EditPredictionId(request_id),
 939                path,
 940                excerpt_range: editable_range,
 941                cursor_offset,
 942                edits,
 943                edit_preview,
 944                snapshot,
 945                input_outline: input_outline.into(),
 946                input_events: input_events.into(),
 947                input_excerpt: input_excerpt.into(),
 948                output_excerpt,
 949                buffer_snapshotted_at,
 950                response_received_at: Instant::now(),
 951            }))
 952        })
 953    }
 954
 955    fn parse_edits(
 956        output_excerpt: Arc<str>,
 957        editable_range: Range<usize>,
 958        snapshot: &BufferSnapshot,
 959    ) -> Result<Vec<(Range<Anchor>, String)>> {
 960        let content = output_excerpt.replace(CURSOR_MARKER, "");
 961
 962        let start_markers = content
 963            .match_indices(EDITABLE_REGION_START_MARKER)
 964            .collect::<Vec<_>>();
 965        anyhow::ensure!(
 966            start_markers.len() == 1,
 967            "expected exactly one start marker, found {}",
 968            start_markers.len()
 969        );
 970
 971        let end_markers = content
 972            .match_indices(EDITABLE_REGION_END_MARKER)
 973            .collect::<Vec<_>>();
 974        anyhow::ensure!(
 975            end_markers.len() == 1,
 976            "expected exactly one end marker, found {}",
 977            end_markers.len()
 978        );
 979
 980        let sof_markers = content
 981            .match_indices(START_OF_FILE_MARKER)
 982            .collect::<Vec<_>>();
 983        anyhow::ensure!(
 984            sof_markers.len() <= 1,
 985            "expected at most one start-of-file marker, found {}",
 986            sof_markers.len()
 987        );
 988
 989        let codefence_start = start_markers[0].0;
 990        let content = &content[codefence_start..];
 991
 992        let newline_ix = content.find('\n').context("could not find newline")?;
 993        let content = &content[newline_ix + 1..];
 994
 995        let codefence_end = content
 996            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 997            .context("could not find end marker")?;
 998        let new_text = &content[..codefence_end];
 999
1000        let old_text = snapshot
1001            .text_for_range(editable_range.clone())
1002            .collect::<String>();
1003
1004        Ok(Self::compute_edits(
1005            old_text,
1006            new_text,
1007            editable_range.start,
1008            snapshot,
1009        ))
1010    }
1011
1012    pub fn compute_edits(
1013        old_text: String,
1014        new_text: &str,
1015        offset: usize,
1016        snapshot: &BufferSnapshot,
1017    ) -> Vec<(Range<Anchor>, String)> {
1018        text_diff(&old_text, new_text)
1019            .into_iter()
1020            .map(|(mut old_range, new_text)| {
1021                old_range.start += offset;
1022                old_range.end += offset;
1023
1024                let prefix_len = common_prefix(
1025                    snapshot.chars_for_range(old_range.clone()),
1026                    new_text.chars(),
1027                );
1028                old_range.start += prefix_len;
1029
1030                let suffix_len = common_prefix(
1031                    snapshot.reversed_chars_for_range(old_range.clone()),
1032                    new_text[prefix_len..].chars().rev(),
1033                );
1034                old_range.end = old_range.end.saturating_sub(suffix_len);
1035
1036                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1037                let range = if old_range.is_empty() {
1038                    let anchor = snapshot.anchor_after(old_range.start);
1039                    anchor..anchor
1040                } else {
1041                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1042                };
1043                (range, new_text)
1044            })
1045            .collect()
1046    }
1047
    /// Returns whether `completion_id` is in the set of rated completions.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
1051
1052    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1053        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1054            let completion = self.shown_completions.pop_back().unwrap();
1055            self.rated_completions.remove(&completion.id);
1056        }
1057        self.shown_completions.push_front(completion.clone());
1058        cx.notify();
1059    }
1060
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// Emits an "Edit Prediction Rated" event carrying the rating, the prediction's recorded
    /// inputs and output excerpt, and the free-form `feedback`; then starts a telemetry flush
    /// (detached, not awaited) and notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1081
    /// Iterates over the shown completions from the front of the queue (where
    /// `completion_shown` inserts) to the back.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
1085
    /// Returns the number of completions currently held in the shown-completions queue.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1089
1090    fn report_changes_for_buffer(
1091        &mut self,
1092        buffer: &Entity<Buffer>,
1093        cx: &mut Context<Self>,
1094    ) -> BufferSnapshot {
1095        self.register_buffer(buffer, cx);
1096
1097        let registered_buffer = self
1098            .registered_buffers
1099            .get_mut(&buffer.entity_id())
1100            .unwrap();
1101        let new_snapshot = buffer.read(cx).snapshot();
1102
1103        if new_snapshot.version != registered_buffer.snapshot.version {
1104            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1105            self.push_event(Event::BufferChange {
1106                old_snapshot,
1107                new_snapshot: new_snapshot.clone(),
1108                timestamp: Instant::now(),
1109            });
1110        }
1111
1112        new_snapshot
1113    }
1114
1115    fn load_data_collection_choices() -> DataCollectionChoice {
1116        let choice = KEY_VALUE_STORE
1117            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1118            .log_err()
1119            .flatten();
1120
1121        match choice.as_deref() {
1122            Some("true") => DataCollectionChoice::Enabled,
1123            Some("false") => DataCollectionChoice::Disabled,
1124            Some(_) => {
1125                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1126                DataCollectionChoice::NotAnswered
1127            }
1128            None => DataCollectionChoice::NotAnswered,
1129        }
1130    }
1131
1132    fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1133        let now = Instant::now();
1134        if let Some(existing_ix) = self
1135            .recent_project_entries
1136            .iter()
1137            .rposition(|(id, _)| *id == project_entry_id)
1138        {
1139            self.recent_project_entries.remove(existing_ix);
1140        }
1141        if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1142            self.recent_project_entries.pop_front();
1143        }
1144        self.recent_project_entries
1145            .push_back((project_entry_id, now));
1146    }
1147
1148    fn recent_files(
1149        &mut self,
1150        now: &Instant,
1151        repository: &Repository,
1152        cx: &Context<Self>,
1153    ) -> Vec<PredictEditsRecentFile> {
1154        let Ok(project) = self
1155            .workspace
1156            .read_with(cx, |workspace, _cx| workspace.project().clone())
1157        else {
1158            return Vec::new();
1159        };
1160        let mut results = Vec::new();
1161        for ix in (0..self.recent_project_entries.len()).rev() {
1162            let (id, last_active_at) = &self.recent_project_entries[ix];
1163            let Some(project_path) = project.read(cx).path_for_entry(*id, cx) else {
1164                self.recent_project_entries.remove(ix);
1165                continue;
1166            };
1167            let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx) else {
1168                // entry not removed since queries involving other repositories might occur later
1169                continue;
1170            };
1171            let Some(repo_path) = repo_path.to_str() else {
1172                // paths may not be valid UTF-8
1173                self.recent_project_entries.remove(ix);
1174                continue;
1175            };
1176            let Ok(active_to_now_ms) = now.duration_since(*last_active_at).as_millis().try_into()
1177            else {
1178                self.recent_project_entries.remove(ix);
1179                continue;
1180            };
1181            results.push(PredictEditsRecentFile {
1182                repo_path: repo_path.to_string(),
1183                active_to_now_ms,
1184            });
1185        }
1186        results
1187    }
1188}
1189
/// Inputs bundled for a predict-edits request.
///
/// NOTE(review): the consumer is outside this excerpt; presumably the request path that sends
/// `body` with `app_version` in the version header — confirm against the caller.
pub struct PerformPredictEditsParams {
    pub client: Arc<Client>,
    pub llm_token: LlmApiToken,
    pub app_version: SemanticVersion,
    pub body: PredictEditsBody,
}
1196
/// Error produced when a server response carries a minimum-required-version header that is
/// newer than the running app's version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: SemanticVersion,
}
1204
/// Returns the length, in UTF-8 bytes, of the longest run of equal characters at the start of
/// both iterators.
///
/// Comparison stops at the first differing pair or as soon as either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1211
1212fn git_repository_for_file(
1213    project: &Entity<Project>,
1214    project_path: &ProjectPath,
1215    cx: &App,
1216) -> Option<Entity<Repository>> {
1217    let git_store = project.read(cx).git_store().read(cx);
1218    git_store
1219        .repository_and_path_for_project_path(project_path, cx)
1220        .map(|(repo, _repo_path)| repo)
1221}
1222
1223fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
1224    let head_sha = repository
1225        .head_commit
1226        .as_ref()
1227        .map(|head_commit| head_commit.sha.to_string());
1228    let remote_origin_url = repository.remote_origin_url.clone();
1229    let remote_upstream_url = repository.remote_upstream_url.clone();
1230    if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1231        return None;
1232    }
1233    Some(PredictEditsGitInfo {
1234        head_sha,
1235        remote_origin_url,
1236        remote_upstream_url,
1237    })
1238}
1239
/// Result of [`gather_context`].
pub struct GatherContextOutput {
    /// The assembled request body.
    pub body: PredictEditsBody,
    /// The input excerpt's editable range, converted to offsets in the buffer snapshot.
    pub editable_range: Range<usize>,
}
1244
1245pub fn gather_context(
1246    project: Option<&Entity<Project>>,
1247    full_path_str: String,
1248    snapshot: &BufferSnapshot,
1249    cursor_point: language::Point,
1250    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1251    can_collect_data: bool,
1252    git_info: Option<PredictEditsGitInfo>,
1253    recent_files: Option<Vec<PredictEditsRecentFile>>,
1254    cx: &App,
1255) -> Task<Result<GatherContextOutput>> {
1256    let local_lsp_store =
1257        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1258    let diagnostic_groups: Vec<(String, serde_json::Value)> =
1259        if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
1260            snapshot
1261                .diagnostic_groups(None)
1262                .into_iter()
1263                .filter_map(|(language_server_id, diagnostic_group)| {
1264                    let language_server =
1265                        local_lsp_store.running_language_server_for_id(language_server_id)?;
1266                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1267                    let language_server_name = language_server.name().to_string();
1268                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
1269                    Some((language_server_name, serialized))
1270                })
1271                .collect::<Vec<_>>()
1272        } else {
1273            Vec::new()
1274        };
1275
1276    cx.background_spawn({
1277        let snapshot = snapshot.clone();
1278        async move {
1279            let diagnostic_groups = if diagnostic_groups.is_empty()
1280                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1281            {
1282                None
1283            } else {
1284                Some(diagnostic_groups)
1285            };
1286
1287            let input_excerpt = excerpt_for_cursor_position(
1288                cursor_point,
1289                &full_path_str,
1290                &snapshot,
1291                MAX_REWRITE_TOKENS,
1292                MAX_CONTEXT_TOKENS,
1293            );
1294            let input_events = make_events_prompt();
1295            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1296
1297            let body = PredictEditsBody {
1298                input_events,
1299                input_excerpt: input_excerpt.prompt,
1300                can_collect_data,
1301                diagnostic_groups,
1302                git_info,
1303                outline: None,
1304                speculated_output: None,
1305                recent_files,
1306            };
1307
1308            Ok(GatherContextOutput {
1309                body,
1310                editable_range,
1311            })
1312        }
1313    })
1314}
1315
1316fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1317    let mut result = String::new();
1318    for event in events.iter().rev() {
1319        let event_string = event.to_prompt();
1320        let event_tokens = tokens_for_bytes(event_string.len());
1321        if event_tokens > remaining_tokens {
1322            break;
1323        }
1324
1325        if !result.is_empty() {
1326            result.insert_str(0, "\n\n");
1327        }
1328        result.insert_str(0, &event_string);
1329        remaining_tokens -= event_tokens;
1330    }
1331    result
1332}
1333
/// Per-buffer bookkeeping for a registered buffer.
struct RegisteredBuffer {
    /// The snapshot last recorded for the buffer; replaced whenever a newer version is
    /// reported.
    snapshot: BufferSnapshot,
    /// Never read (hence the underscore); presumably held so the subscriptions stay alive for
    /// as long as the buffer is registered — TODO confirm.
    _subscriptions: [gpui::Subscription; 2],
}
1338
/// A recorded change that can be rendered into prompt text via `to_prompt`.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// The instant at which the change was recorded.
        timestamp: Instant,
    },
}
1347
1348impl Event {
1349    fn to_prompt(&self) -> String {
1350        match self {
1351            Event::BufferChange {
1352                old_snapshot,
1353                new_snapshot,
1354                ..
1355            } => {
1356                let mut prompt = String::new();
1357
1358                let old_path = old_snapshot
1359                    .file()
1360                    .map(|f| f.path().as_ref())
1361                    .unwrap_or(Path::new("untitled"));
1362                let new_path = new_snapshot
1363                    .file()
1364                    .map(|f| f.path().as_ref())
1365                    .unwrap_or(Path::new("untitled"));
1366                if old_path != new_path {
1367                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1368                }
1369
1370                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1371                if !diff.is_empty() {
1372                    write!(
1373                        prompt,
1374                        "User edited {:?}:\n```diff\n{}\n```",
1375                        new_path, diff
1376                    )
1377                    .unwrap();
1378                }
1379
1380                prompt
1381            }
1382        }
1383    }
1384}
1385
/// An edit prediction paired with the buffer it belongs to.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1391
1392impl CurrentEditPrediction {
1393    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1394        if self.buffer_id != old_completion.buffer_id {
1395            return true;
1396        }
1397
1398        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1399            return true;
1400        };
1401        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1402            return false;
1403        };
1404
1405        if old_edits.len() == 1 && new_edits.len() == 1 {
1406            let (old_range, old_text) = &old_edits[0];
1407            let (new_range, new_text) = &new_edits[0];
1408            new_range == old_range && new_text.starts_with(old_text)
1409        } else {
1410            true
1411        }
1412    }
1413}
1414
/// A completion request that is still in flight.
struct PendingCompletion {
    /// Identifier taken from the provider's `next_pending_completion_id` counter.
    id: usize,
    /// Never read (hence the underscore); presumably held so the request task keeps running
    /// while this entry exists — TODO confirm.
    _task: Task<()>,
}
1419
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been given yet.
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once an explicit choice (enabled or disabled) has been made.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice: `Enabled` becomes `Disabled`, and both `Disabled` and
    /// `NotAnswered` become `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1459
/// Data-collection state attached to an edit-prediction provider.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher consulted to decide whether the project counts as open source; `None` whenever
    /// `choice` is `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1465
1466impl ProviderDataCollection {
1467    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1468        let choice_and_watcher = buffer.and_then(|buffer| {
1469            let file = buffer.read(cx).file()?;
1470
1471            if !file.is_local() || file.is_private() {
1472                return None;
1473            }
1474
1475            let zeta = zeta.read(cx);
1476            let choice = zeta.data_collection_choice.clone();
1477
1478            let license_detection_watcher = zeta
1479                .license_detection_watchers
1480                .get(&file.worktree_id(cx))
1481                .cloned()?;
1482
1483            Some((choice, license_detection_watcher))
1484        });
1485
1486        if let Some((choice, watcher)) = choice_and_watcher {
1487            ProviderDataCollection {
1488                choice: Some(choice),
1489                license_detection_watcher: Some(watcher),
1490            }
1491        } else {
1492            ProviderDataCollection {
1493                choice: None,
1494                license_detection_watcher: None,
1495            }
1496        }
1497    }
1498
1499    pub fn can_collect_data(&self, cx: &App) -> bool {
1500        self.is_data_collection_enabled(cx) && self.is_project_open_source()
1501    }
1502
1503    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1504        self.choice
1505            .as_ref()
1506            .is_some_and(|choice| choice.read(cx).is_enabled())
1507    }
1508
1509    fn is_project_open_source(&self) -> bool {
1510        self.license_detection_watcher
1511            .as_ref()
1512            .is_some_and(|watcher| watcher.is_project_open_source())
1513    }
1514
1515    pub fn toggle(&mut self, cx: &mut App) {
1516        if let Some(choice) = self.choice.as_mut() {
1517            let new_choice = choice.update(cx, |choice, _cx| {
1518                let new_choice = choice.toggle();
1519                *choice = new_choice;
1520                new_choice
1521            });
1522
1523            db::write_and_log(cx, move || {
1524                KEY_VALUE_STORE.write_kvp(
1525                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1526                    new_choice.is_enabled().to_string(),
1527                )
1528            });
1529        }
1530    }
1531}
1532
1533async fn llm_token_retry(
1534    llm_token: &LlmApiToken,
1535    client: &Arc<Client>,
1536    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1537) -> Result<Response<AsyncBody>> {
1538    let mut did_retry = false;
1539    let http_client = client.http_client();
1540    let mut token = llm_token.acquire(client).await?;
1541    loop {
1542        let request = build_request(token.clone())?;
1543        let response = http_client.send(request).await?;
1544
1545        if !did_retry
1546            && !response.status().is_success()
1547            && response
1548                .headers()
1549                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1550                .is_some()
1551        {
1552            did_retry = true;
1553            token = llm_token.refresh(client).await?;
1554            continue;
1555        }
1556
1557        return Ok(response);
1558    }
1559}
1560
/// Edit-prediction provider backed by a [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight completion requests; holds at most two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next pending completion; incremented on every refresh that starts one.
    next_pending_completion_id: usize,
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider. The struct is always present; it is its inner
    /// `choice` that is `None` when data collection is not possible.
    provider_data_collection: ProviderDataCollection,
    /// When the last completion request was issued; used together with `THROTTLE_TIMEOUT`.
    last_request_timestamp: Instant,
}
1570
impl ZetaEditPredictionProvider {
    /// Delay applied relative to the previous request: a refresh waits until
    /// `last_request_timestamp + THROTTLE_TIMEOUT` before issuing its request.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current completion. The last-request timestamp
    /// starts at the moment of construction.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1585
1586impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Returns the provider's name, `"zed-predict"`.
    fn name() -> &'static str {
        "zed-predict"
    }
1590
    /// Returns the provider's display name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }
1594
    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
1598
    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
1602
1603    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1604        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1605
1606        if self.provider_data_collection.is_data_collection_enabled(cx) {
1607            DataCollectionState::Enabled {
1608                is_project_open_source,
1609            }
1610        } else {
1611            DataCollectionState::Disabled {
1612                is_project_open_source,
1613            }
1614        }
1615    }
1616
    /// Delegates to [`ProviderDataCollection::toggle`].
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }
1620
    /// Returns the edit-prediction usage reported by the underlying `Zeta` entity, if any.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
1624
    /// Always returns `true`; the buffer and cursor position are ignored.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// Returns `true` while at least one completion request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }
1636
1637    fn refresh(
1638        &mut self,
1639        project: Option<Entity<Project>>,
1640        buffer: Entity<Buffer>,
1641        position: language::Anchor,
1642        _debounce: bool,
1643        cx: &mut Context<Self>,
1644    ) {
1645        if self.zeta.read(cx).update_required {
1646            return;
1647        }
1648
1649        if self
1650            .zeta
1651            .read(cx)
1652            .user_store
1653            .read_with(cx, |user_store, _cx| {
1654                user_store.account_too_young() || user_store.has_overdue_invoices()
1655            })
1656        {
1657            return;
1658        }
1659
1660        if let Some(current_completion) = self.current_completion.as_ref() {
1661            let snapshot = buffer.read(cx).snapshot();
1662            if current_completion
1663                .completion
1664                .interpolate(&snapshot)
1665                .is_some()
1666            {
1667                return;
1668            }
1669        }
1670
1671        let pending_completion_id = self.next_pending_completion_id;
1672        self.next_pending_completion_id += 1;
1673        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1674        let last_request_timestamp = self.last_request_timestamp;
1675
1676        let task = cx.spawn(async move |this, cx| {
1677            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1678                .checked_duration_since(Instant::now())
1679            {
1680                cx.background_executor().timer(timeout).await;
1681            }
1682
1683            let completion_request = this.update(cx, |this, cx| {
1684                this.last_request_timestamp = Instant::now();
1685                this.zeta.update(cx, |zeta, cx| {
1686                    zeta.request_completion(
1687                        project.as_ref(),
1688                        &buffer,
1689                        position,
1690                        can_collect_data,
1691                        cx,
1692                    )
1693                })
1694            });
1695
1696            let completion = match completion_request {
1697                Ok(completion_request) => {
1698                    let completion_request = completion_request.await;
1699                    completion_request.map(|c| {
1700                        c.map(|completion| CurrentEditPrediction {
1701                            buffer_id: buffer.entity_id(),
1702                            completion,
1703                        })
1704                    })
1705                }
1706                Err(error) => Err(error),
1707            };
1708            let Some(new_completion) = completion
1709                .context("edit prediction failed")
1710                .log_err()
1711                .flatten()
1712            else {
1713                this.update(cx, |this, cx| {
1714                    if this.pending_completions[0].id == pending_completion_id {
1715                        this.pending_completions.remove(0);
1716                    } else {
1717                        this.pending_completions.clear();
1718                    }
1719
1720                    cx.notify();
1721                })
1722                .ok();
1723                return;
1724            };
1725
1726            this.update(cx, |this, cx| {
1727                if this.pending_completions[0].id == pending_completion_id {
1728                    this.pending_completions.remove(0);
1729                } else {
1730                    this.pending_completions.clear();
1731                }
1732
1733                if let Some(old_completion) = this.current_completion.as_ref() {
1734                    let snapshot = buffer.read(cx).snapshot();
1735                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1736                        this.zeta.update(cx, |zeta, cx| {
1737                            zeta.completion_shown(&new_completion.completion, cx);
1738                        });
1739                        this.current_completion = Some(new_completion);
1740                    }
1741                } else {
1742                    this.zeta.update(cx, |zeta, cx| {
1743                        zeta.completion_shown(&new_completion.completion, cx);
1744                    });
1745                    this.current_completion = Some(new_completion);
1746                }
1747
1748                cx.notify();
1749            })
1750            .ok();
1751        });
1752
1753        // We always maintain at most two pending completions. When we already
1754        // have two, we replace the newest one.
1755        if self.pending_completions.len() <= 1 {
1756            self.pending_completions.push(PendingCompletion {
1757                id: pending_completion_id,
1758                _task: task,
1759            });
1760        } else if self.pending_completions.len() == 2 {
1761            self.pending_completions.pop();
1762            self.pending_completions.push(PendingCompletion {
1763                id: pending_completion_id,
1764                _task: task,
1765            });
1766        }
1767    }
1768
1769    fn cycle(
1770        &mut self,
1771        _buffer: Entity<Buffer>,
1772        _cursor_position: language::Anchor,
1773        _direction: edit_prediction::Direction,
1774        _cx: &mut Context<Self>,
1775    ) {
1776        // Right now we don't support cycling.
1777    }
1778
1779    fn accept(&mut self, cx: &mut Context<Self>) {
1780        let completion_id = self
1781            .current_completion
1782            .as_ref()
1783            .map(|completion| completion.completion.id);
1784        if let Some(completion_id) = completion_id {
1785            self.zeta
1786                .update(cx, |zeta, cx| {
1787                    zeta.accept_edit_prediction(completion_id, cx)
1788                })
1789                .detach();
1790        }
1791        self.pending_completions.clear();
1792    }
1793
1794    fn discard(&mut self, _cx: &mut Context<Self>) {
1795        self.pending_completions.clear();
1796        self.current_completion.take();
1797    }
1798
1799    fn suggest(
1800        &mut self,
1801        buffer: &Entity<Buffer>,
1802        cursor_position: language::Anchor,
1803        cx: &mut Context<Self>,
1804    ) -> Option<edit_prediction::EditPrediction> {
1805        let CurrentEditPrediction {
1806            buffer_id,
1807            completion,
1808            ..
1809        } = self.current_completion.as_mut()?;
1810
1811        // Invalidate previous completion if it was generated for a different buffer.
1812        if *buffer_id != buffer.entity_id() {
1813            self.current_completion.take();
1814            return None;
1815        }
1816
1817        let buffer = buffer.read(cx);
1818        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1819            self.current_completion.take();
1820            return None;
1821        };
1822
1823        let cursor_row = cursor_position.to_point(buffer).row;
1824        let (closest_edit_ix, (closest_edit_range, _)) =
1825            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1826                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1827                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1828                cmp::min(distance_from_start, distance_from_end)
1829            })?;
1830
1831        let mut edit_start_ix = closest_edit_ix;
1832        for (range, _) in edits[..edit_start_ix].iter().rev() {
1833            let distance_from_closest_edit =
1834                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1835            if distance_from_closest_edit <= 1 {
1836                edit_start_ix -= 1;
1837            } else {
1838                break;
1839            }
1840        }
1841
1842        let mut edit_end_ix = closest_edit_ix + 1;
1843        for (range, _) in &edits[edit_end_ix..] {
1844            let distance_from_closest_edit =
1845                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1846            if distance_from_closest_edit <= 1 {
1847                edit_end_ix += 1;
1848            } else {
1849                break;
1850            }
1851        }
1852
1853        Some(edit_prediction::EditPrediction {
1854            id: Some(completion.id.to_string().into()),
1855            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1856            edit_preview: Some(completion.edit_preview.clone()),
1857        })
1858    }
1859}
1860
/// Roughly converts a length in bytes into a number of model tokens (rounding down).
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed number of string bytes per token when limiting model input. The guess is kept
    /// deliberately low: dividing by a smaller number yields a larger token estimate.
    const ASSUMED_BYTES_PER_TOKEN: usize = 3;

    let estimated_tokens = bytes / ASSUMED_BYTES_PER_TOKEN;
    estimated_tokens
}
1867
1868#[cfg(test)]
1869mod tests {
1870    use client::UserStore;
1871    use client::test::FakeServer;
1872    use clock::FakeSystemClock;
1873    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1874    use gpui::TestAppContext;
1875    use http_client::FakeHttpClient;
1876    use indoc::indoc;
1877    use language::Point;
1878    use settings::SettingsStore;
1879
1880    use super::*;
1881
1882    #[gpui::test]
1883    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1884        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1885        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1886            to_completion_edits(
1887                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1888                &buffer,
1889                cx,
1890            )
1891            .into()
1892        });
1893
1894        let edit_preview = cx
1895            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1896            .await;
1897
1898        let completion = EditPrediction {
1899            edits,
1900            edit_preview,
1901            path: Path::new("").into(),
1902            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1903            id: EditPredictionId(Uuid::new_v4()),
1904            excerpt_range: 0..0,
1905            cursor_offset: 0,
1906            input_outline: "".into(),
1907            input_events: "".into(),
1908            input_excerpt: "".into(),
1909            output_excerpt: "".into(),
1910            buffer_snapshotted_at: Instant::now(),
1911            response_received_at: Instant::now(),
1912        };
1913
1914        cx.update(|cx| {
1915            assert_eq!(
1916                from_completion_edits(
1917                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1918                    &buffer,
1919                    cx
1920                ),
1921                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1922            );
1923
1924            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1925            assert_eq!(
1926                from_completion_edits(
1927                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1928                    &buffer,
1929                    cx
1930                ),
1931                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1932            );
1933
1934            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1935            assert_eq!(
1936                from_completion_edits(
1937                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1938                    &buffer,
1939                    cx
1940                ),
1941                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1942            );
1943
1944            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1945            assert_eq!(
1946                from_completion_edits(
1947                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1948                    &buffer,
1949                    cx
1950                ),
1951                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1952            );
1953
1954            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1955            assert_eq!(
1956                from_completion_edits(
1957                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1958                    &buffer,
1959                    cx
1960                ),
1961                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1962            );
1963
1964            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1965            assert_eq!(
1966                from_completion_edits(
1967                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1968                    &buffer,
1969                    cx
1970                ),
1971                vec![(9..11, "".to_string())]
1972            );
1973
1974            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1975            assert_eq!(
1976                from_completion_edits(
1977                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1978                    &buffer,
1979                    cx
1980                ),
1981                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1982            );
1983
1984            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1985            assert_eq!(
1986                from_completion_edits(
1987                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1988                    &buffer,
1989                    cx
1990                ),
1991                vec![(4..4, "M".to_string())]
1992            );
1993
1994            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1995            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1996        })
1997    }
1998
1999    #[gpui::test]
2000    async fn test_clean_up_diff(cx: &mut TestAppContext) {
2001        cx.update(|cx| {
2002            let settings_store = SettingsStore::test(cx);
2003            cx.set_global(settings_store);
2004            client::init_settings(cx);
2005        });
2006
2007        let edits = edits_for_prediction(
2008            indoc! {"
2009                fn main() {
2010                    let word_1 = \"lorem\";
2011                    let range = word.len()..word.len();
2012                }
2013            "},
2014            indoc! {"
2015                <|editable_region_start|>
2016                fn main() {
2017                    let word_1 = \"lorem\";
2018                    let range = word_1.len()..word_1.len();
2019                }
2020
2021                <|editable_region_end|>
2022            "},
2023            cx,
2024        )
2025        .await;
2026        assert_eq!(
2027            edits,
2028            [
2029                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
2030                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
2031            ]
2032        );
2033
2034        let edits = edits_for_prediction(
2035            indoc! {"
2036                fn main() {
2037                    let story = \"the quick\"
2038                }
2039            "},
2040            indoc! {"
2041                <|editable_region_start|>
2042                fn main() {
2043                    let story = \"the quick brown fox jumps over the lazy dog\";
2044                }
2045
2046                <|editable_region_end|>
2047            "},
2048            cx,
2049        )
2050        .await;
2051        assert_eq!(
2052            edits,
2053            [
2054                (
2055                    Point::new(1, 26)..Point::new(1, 26),
2056                    " brown fox jumps over the lazy dog".to_string()
2057                ),
2058                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
2059            ]
2060        );
2061    }
2062
2063    #[gpui::test]
2064    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2065        cx.update(|cx| {
2066            let settings_store = SettingsStore::test(cx);
2067            cx.set_global(settings_store);
2068            client::init_settings(cx);
2069        });
2070
2071        let buffer_content = "lorem\n";
2072        let completion_response = indoc! {"
2073            ```animals.js
2074            <|start_of_file|>
2075            <|editable_region_start|>
2076            lorem
2077            ipsum
2078            <|editable_region_end|>
2079            ```"};
2080
2081        let http_client = FakeHttpClient::create(move |req| async move {
2082            match (req.method(), req.uri().path()) {
2083                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2084                    .status(200)
2085                    .body(
2086                        serde_json::to_string(&CreateLlmTokenResponse {
2087                            token: LlmToken("the-llm-token".to_string()),
2088                        })
2089                        .unwrap()
2090                        .into(),
2091                    )
2092                    .unwrap()),
2093                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2094                    .status(200)
2095                    .body(
2096                        serde_json::to_string(&PredictEditsResponse {
2097                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
2098                                .unwrap(),
2099                            output_excerpt: completion_response.to_string(),
2100                        })
2101                        .unwrap()
2102                        .into(),
2103                    )
2104                    .unwrap()),
2105                _ => Ok(http_client::Response::builder()
2106                    .status(404)
2107                    .body("Not Found".into())
2108                    .unwrap()),
2109            }
2110        });
2111
2112        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2113        cx.update(|cx| {
2114            RefreshLlmTokenListener::register(client.clone(), cx);
2115        });
2116        // Construct the fake server to authenticate.
2117        let _server = FakeServer::for_client(42, &client, cx).await;
2118        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2119        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2120
2121        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2122        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2123        let completion_task = zeta.update(cx, |zeta, cx| {
2124            zeta.request_completion(None, &buffer, cursor, false, cx)
2125        });
2126
2127        let completion = completion_task.await.unwrap().unwrap();
2128        buffer.update(cx, |buffer, cx| {
2129            buffer.edit(completion.edits.iter().cloned(), None, cx)
2130        });
2131        assert_eq!(
2132            buffer.read_with(cx, |buffer, _| buffer.text()),
2133            "lorem\nipsum"
2134        );
2135    }
2136
2137    async fn edits_for_prediction(
2138        buffer_content: &str,
2139        completion_response: &str,
2140        cx: &mut TestAppContext,
2141    ) -> Vec<(Range<Point>, String)> {
2142        let completion_response = completion_response.to_string();
2143        let http_client = FakeHttpClient::create(move |req| {
2144            let completion = completion_response.clone();
2145            async move {
2146                match (req.method(), req.uri().path()) {
2147                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2148                        .status(200)
2149                        .body(
2150                            serde_json::to_string(&CreateLlmTokenResponse {
2151                                token: LlmToken("the-llm-token".to_string()),
2152                            })
2153                            .unwrap()
2154                            .into(),
2155                        )
2156                        .unwrap()),
2157                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
2158                        .status(200)
2159                        .body(
2160                            serde_json::to_string(&PredictEditsResponse {
2161                                request_id: Uuid::new_v4(),
2162                                output_excerpt: completion,
2163                            })
2164                            .unwrap()
2165                            .into(),
2166                        )
2167                        .unwrap()),
2168                    _ => Ok(http_client::Response::builder()
2169                        .status(404)
2170                        .body("Not Found".into())
2171                        .unwrap()),
2172                }
2173            }
2174        });
2175
2176        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2177        cx.update(|cx| {
2178            RefreshLlmTokenListener::register(client.clone(), cx);
2179        });
2180        // Construct the fake server to authenticate.
2181        let _server = FakeServer::for_client(42, &client, cx).await;
2182        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2183        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
2184
2185        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2186        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
2187        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2188        let completion_task = zeta.update(cx, |zeta, cx| {
2189            zeta.request_completion(None, &buffer, cursor, false, cx)
2190        });
2191
2192        let completion = completion_task.await.unwrap().unwrap();
2193        completion
2194            .edits
2195            .iter()
2196            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
2197            .collect::<Vec<_>>()
2198    }
2199
2200    fn to_completion_edits(
2201        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2202        buffer: &Entity<Buffer>,
2203        cx: &App,
2204    ) -> Vec<(Range<Anchor>, String)> {
2205        let buffer = buffer.read(cx);
2206        iterator
2207            .into_iter()
2208            .map(|(range, text)| {
2209                (
2210                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2211                    text,
2212                )
2213            })
2214            .collect()
2215    }
2216
2217    fn from_completion_edits(
2218        editor_edits: &[(Range<Anchor>, String)],
2219        buffer: &Entity<Buffer>,
2220        cx: &App,
2221    ) -> Vec<(Range<usize>, String)> {
2222        let buffer = buffer.read(cx);
2223        editor_edits
2224            .iter()
2225            .map(|(range, text)| {
2226                (
2227                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
2228                    text.clone(),
2229                )
2230            })
2231            .collect()
2232    }
2233
2234    #[ctor::ctor]
2235    fn init_logger() {
2236        zlog::init_test();
2237    }
2238}