zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9use arrayvec::ArrayVec;
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction::DataCollectionState;
  13pub use init::*;
  14use license_detection::LicenseDetectionWatcher;
  15use project::git_store::Repository;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use client::{Client, EditPredictionUsage, UserStore};
  20use cloud_llm_client::{
  21    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  22    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
  23    ZED_VERSION_HEADER_NAME,
  24};
  25use collections::{HashMap, HashSet, VecDeque};
  26use futures::AsyncReadExt;
  27use gpui::{
  28    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  29    Subscription, Task, WeakEntity, actions,
  30};
  31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  32use input_excerpt::excerpt_for_cursor_position;
  33use language::{
  34    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  35};
  36use language_model::{LlmApiToken, RefreshLlmTokenListener};
  37use project::{Project, ProjectEntryId, ProjectPath};
  38use release_channel::AppVersion;
  39use settings::WorktreeId;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    mem,
  46    ops::Range,
  47    path::Path,
  48    rc::Rc,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::EditPredictionRating;
  53use thiserror::Error;
  54use util::ResultExt;
  55use uuid::Uuid;
  56use workspace::Workspace;
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  58use worktree::Worktree;
  59
  60const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
  61const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
  62const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
  63const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
  64const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  65const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  66
  67const MAX_CONTEXT_TOKENS: usize = 150;
  68const MAX_REWRITE_TOKENS: usize = 350;
  69const MAX_EVENT_TOKENS: usize = 500;
  70const MAX_DIAGNOSTIC_GROUPS: usize = 10;
  71
  72/// Maximum number of events to track.
  73const MAX_EVENT_COUNT: usize = 16;
  74
  75/// Maximum number of recent files to track.
  76const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
  77
  78/// Maximum number of edit predictions to store for feedback.
  79const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
  80
// Declares the actions in the `edit_prediction` namespace. The `///` line
// inside the macro is kept verbatim because gpui's `actions!` presumably
// exposes it as the action's documentation (NOTE(review): confirm).
// `Zeta::clear_history` is what empties the event history.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  88
  89#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  90pub struct EditPredictionId(Uuid);
  91
  92impl From<EditPredictionId> for gpui::ElementId {
  93    fn from(value: EditPredictionId) -> Self {
  94        gpui::ElementId::Uuid(value.0)
  95    }
  96}
  97
  98impl std::fmt::Display for EditPredictionId {
  99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 100        write!(f, "{}", self.0)
 101    }
 102}
 103
 104struct ZedPredictUpsell;
 105
 106impl Dismissable for ZedPredictUpsell {
 107    const KEY: &'static str = "dismissed-edit-predict-upsell";
 108
 109    fn dismissed() -> bool {
 110        // To make this backwards compatible with older versions of Zed, we
 111        // check if the user has seen the previous Edit Prediction Onboarding
 112        // before, by checking the data collection choice which was written to
 113        // the database once the user clicked on "Accept and Enable"
 114        if KEY_VALUE_STORE
 115            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 116            .log_err()
 117            .is_some_and(|s| s.is_some())
 118        {
 119            return true;
 120        }
 121
 122        KEY_VALUE_STORE
 123            .read_kvp(Self::KEY)
 124            .log_err()
 125            .is_some_and(|s| s.is_some())
 126    }
 127}
 128
 129pub fn should_show_upsell_modal() -> bool {
 130    !ZedPredictUpsell::dismissed()
 131}
 132
 133#[derive(Clone)]
 134struct ZetaGlobal(Entity<Zeta>);
 135
 136impl Global for ZetaGlobal {}
 137
 138#[derive(Clone)]
 139pub struct EditPrediction {
 140    id: EditPredictionId,
 141    path: Arc<Path>,
 142    excerpt_range: Range<usize>,
 143    cursor_offset: usize,
 144    edits: Arc<[(Range<Anchor>, String)]>,
 145    snapshot: BufferSnapshot,
 146    edit_preview: EditPreview,
 147    input_outline: Arc<str>,
 148    input_events: Arc<str>,
 149    input_excerpt: Arc<str>,
 150    output_excerpt: Arc<str>,
 151    buffer_snapshotted_at: Instant,
 152    response_received_at: Instant,
 153}
 154
 155impl EditPrediction {
 156    fn latency(&self) -> Duration {
 157        self.response_received_at
 158            .duration_since(self.buffer_snapshotted_at)
 159    }
 160
 161    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 162        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 163    }
 164}
 165
 166fn interpolate(
 167    old_snapshot: &BufferSnapshot,
 168    new_snapshot: &BufferSnapshot,
 169    current_edits: Arc<[(Range<Anchor>, String)]>,
 170) -> Option<Vec<(Range<Anchor>, String)>> {
 171    let mut edits = Vec::new();
 172
 173    let mut model_edits = current_edits.iter().peekable();
 174    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 175        while let Some((model_old_range, _)) = model_edits.peek() {
 176            let model_old_range = model_old_range.to_offset(old_snapshot);
 177            if model_old_range.end < user_edit.old.start {
 178                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 179                edits.push((model_old_range.clone(), model_new_text.clone()));
 180            } else {
 181                break;
 182            }
 183        }
 184
 185        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 186            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 187            if user_edit.old == model_old_offset_range {
 188                let user_new_text = new_snapshot
 189                    .text_for_range(user_edit.new.clone())
 190                    .collect::<String>();
 191
 192                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 193                    if !model_suffix.is_empty() {
 194                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 195                        edits.push((anchor..anchor, model_suffix.to_string()));
 196                    }
 197
 198                    model_edits.next();
 199                    continue;
 200                }
 201            }
 202        }
 203
 204        return None;
 205    }
 206
 207    edits.extend(model_edits.cloned());
 208
 209    if edits.is_empty() { None } else { Some(edits) }
 210}
 211
 212impl std::fmt::Debug for EditPrediction {
 213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 214        f.debug_struct("EditPrediction")
 215            .field("id", &self.id)
 216            .field("path", &self.path)
 217            .field("edits", &self.edits)
 218            .finish_non_exhaustive()
 219    }
 220}
 221
 222pub struct Zeta {
 223    workspace: WeakEntity<Workspace>,
 224    client: Arc<Client>,
 225    events: VecDeque<Event>,
 226    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 227    shown_completions: VecDeque<EditPrediction>,
 228    rated_completions: HashSet<EditPredictionId>,
 229    data_collection_choice: Entity<DataCollectionChoice>,
 230    llm_token: LlmApiToken,
 231    _llm_token_subscription: Subscription,
 232    /// Whether an update to a newer version of Zed is required to continue using Zeta.
 233    update_required: bool,
 234    user_store: Entity<UserStore>,
 235    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 236    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
 237}
 238
 239impl Zeta {
 240    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 241        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 242    }
 243
 244    pub fn register(
 245        workspace: Option<Entity<Workspace>>,
 246        worktree: Option<Entity<Worktree>>,
 247        client: Arc<Client>,
 248        user_store: Entity<UserStore>,
 249        cx: &mut App,
 250    ) -> Entity<Self> {
 251        let this = Self::global(cx).unwrap_or_else(|| {
 252            let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
 253            cx.set_global(ZetaGlobal(entity.clone()));
 254            entity
 255        });
 256
 257        this.update(cx, move |this, cx| {
 258            if let Some(worktree) = worktree {
 259                let worktree_id = worktree.read(cx).id();
 260                this.license_detection_watchers
 261                    .entry(worktree_id)
 262                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 263            }
 264        });
 265
 266        this
 267    }
 268
 269    pub fn clear_history(&mut self) {
 270        self.events.clear();
 271    }
 272
 273    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 274        self.user_store.read(cx).edit_prediction_usage()
 275    }
 276
 277    fn new(
 278        workspace: Option<Entity<Workspace>>,
 279        client: Arc<Client>,
 280        user_store: Entity<UserStore>,
 281        cx: &mut Context<Self>,
 282    ) -> Self {
 283        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 284
 285        let data_collection_choice = Self::load_data_collection_choices();
 286        let data_collection_choice = cx.new(|_| data_collection_choice);
 287
 288        if let Some(workspace) = &workspace {
 289            cx.subscribe(
 290                &workspace.read(cx).project().clone(),
 291                |this, _workspace, event, _cx| match event {
 292                    project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
 293                        this.push_recent_project_entry(*project_entry_id)
 294                    }
 295                    _ => {}
 296                },
 297            )
 298            .detach();
 299        }
 300
 301        Self {
 302            workspace: workspace.map_or_else(
 303                || WeakEntity::new_invalid(),
 304                |workspace| workspace.downgrade(),
 305            ),
 306            client,
 307            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 308            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
 309            rated_completions: HashSet::default(),
 310            registered_buffers: HashMap::default(),
 311            data_collection_choice,
 312            llm_token: LlmApiToken::default(),
 313            _llm_token_subscription: cx.subscribe(
 314                &refresh_llm_token_listener,
 315                |this, _listener, _event, cx| {
 316                    let client = this.client.clone();
 317                    let llm_token = this.llm_token.clone();
 318                    cx.spawn(async move |_this, _cx| {
 319                        llm_token.refresh(&client).await?;
 320                        anyhow::Ok(())
 321                    })
 322                    .detach_and_log_err(cx);
 323                },
 324            ),
 325            update_required: false,
 326            license_detection_watchers: HashMap::default(),
 327            user_store,
 328            recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
 329        }
 330    }
 331
 332    fn push_event(&mut self, event: Event) {
 333        if let Some(Event::BufferChange {
 334            new_snapshot: last_new_snapshot,
 335            timestamp: last_timestamp,
 336            ..
 337        }) = self.events.back_mut()
 338        {
 339            // Coalesce edits for the same buffer when they happen one after the other.
 340            let Event::BufferChange {
 341                old_snapshot,
 342                new_snapshot,
 343                timestamp,
 344            } = &event;
 345
 346            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 347                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 348                && old_snapshot.version == last_new_snapshot.version
 349            {
 350                *last_new_snapshot = new_snapshot.clone();
 351                *last_timestamp = *timestamp;
 352                return;
 353            }
 354        }
 355
 356        if self.events.len() >= MAX_EVENT_COUNT {
 357            // These are halved instead of popping to improve prompt caching.
 358            self.events.drain(..MAX_EVENT_COUNT / 2);
 359        }
 360
 361        self.events.push_back(event);
 362    }
 363
 364    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 365        let buffer_id = buffer.entity_id();
 366        let weak_buffer = buffer.downgrade();
 367
 368        if let std::collections::hash_map::Entry::Vacant(entry) =
 369            self.registered_buffers.entry(buffer_id)
 370        {
 371            let snapshot = buffer.read(cx).snapshot();
 372
 373            entry.insert(RegisteredBuffer {
 374                snapshot,
 375                _subscriptions: [
 376                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 377                        this.handle_buffer_event(buffer, event, cx);
 378                    }),
 379                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 380                        this.registered_buffers.remove(&weak_buffer.entity_id());
 381                    }),
 382                ],
 383            });
 384        };
 385    }
 386
 387    fn handle_buffer_event(
 388        &mut self,
 389        buffer: Entity<Buffer>,
 390        event: &language::BufferEvent,
 391        cx: &mut Context<Self>,
 392    ) {
 393        if let language::BufferEvent::Edited = event {
 394            self.report_changes_for_buffer(&buffer, cx);
 395        }
 396    }
 397
    /// Shared implementation behind `request_completion` and the test-only
    /// `fake_completion`. It gathers context for `buffer` at `cursor`, hands
    /// the request body to `perform_predict_edits`, and converts the response
    /// with `process_completion_response`.
    ///
    /// - `workspace`: used only to show an "Update Zed" notification when the
    ///   request fails with [`ZedUpdateRequiredError`].
    /// - `project`: passed to context gathering. When `can_collect_data` is
    ///   also set and the buffer's file lives in a git repository, git info
    ///   and recent files are attached as well.
    /// - `perform_predict_edits`: performs the actual request. It is injected
    ///   so tests can supply canned responses.
    ///
    /// The returned task resolves to whatever `process_completion_response`
    /// produces. It fails with the context-gathering or request error if
    /// either of those fails.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Captured before `report_changes_for_buffer`; it is the start point
        // for every latency figure recorded below.
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        // Owned copy, moved into `make_events_prompt` below.
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Git info and recent files are attached only when data collection is
        // allowed and a repository is found for the buffer's file.
        let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
            (can_collect_data, project, snapshot.file())
            && let Some(repository) =
                git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        {
            let repository = repository.read(cx);
            let git_info = make_predict_edits_git_info(repository);
            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
            (git_info, Some(recent_files))
        } else {
            (None, None)
        };

        // Buffers without a file are reported under the placeholder "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            recent_files,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Copies of the request inputs, kept for the resulting prediction
            // because `body` is moved into the request below.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // Version too old: remember that an update is required
                    // and, if a workspace is available, show a notification
                    // linking to the releases page. The error is still
                    // returned to the caller.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Forward any usage info returned with the response to the user
            // store. This is skipped silently if `Zeta` has been dropped.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // Record latency for ~1% of requests: `<= 2` admits 3 of the 256
            // possible `u8` values (about 1.2%).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 560
    /// Generates several example completions of various states to fill the
    /// Zeta completion modal (test-support only).
    ///
    /// Seven canned responses are run through `fake_completion` against a
    /// fresh local buffer. Once all of them resolve, the edits of the
    /// entries at index 2 and 3 in `shown_completions` are cleared. Those two
    /// excerpts reproduce the buffer text unchanged.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        // Each task feeds one canned response through the regular pipeline.
        // The excerpt strings are runtime data and must stay byte-exact.
        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Test-support only: any failing fake completion panics here.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Clear the edits of entries 2 and 3 so the modal also has
            // predictions with no edits to display.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 713
 714    #[cfg(any(test, feature = "test-support"))]
 715    pub fn fake_completion(
 716        &mut self,
 717        project: Option<&Entity<Project>>,
 718        buffer: &Entity<Buffer>,
 719        position: language::Anchor,
 720        response: PredictEditsResponse,
 721        cx: &mut Context<Self>,
 722    ) -> Task<Result<Option<EditPrediction>>> {
 723        use std::future::ready;
 724
 725        self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
 726            ready(Ok((response, None)))
 727        })
 728    }
 729
 730    pub fn request_completion(
 731        &mut self,
 732        project: Option<&Entity<Project>>,
 733        buffer: &Entity<Buffer>,
 734        position: language::Anchor,
 735        can_collect_data: bool,
 736        cx: &mut Context<Self>,
 737    ) -> Task<Result<Option<EditPrediction>>> {
 738        self.request_completion_impl(
 739            self.workspace.upgrade(),
 740            project,
 741            buffer,
 742            position,
 743            can_collect_data,
 744            cx,
 745            Self::perform_predict_edits,
 746        )
 747    }
 748
 749    pub fn perform_predict_edits(
 750        params: PerformPredictEditsParams,
 751    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 752        async move {
 753            let PerformPredictEditsParams {
 754                client,
 755                llm_token,
 756                app_version,
 757                body,
 758                ..
 759            } = params;
 760
 761            let http_client = client.http_client();
 762            let mut token = llm_token.acquire(&client).await?;
 763            let mut did_retry = false;
 764
 765            loop {
 766                let request_builder = http_client::Request::builder().method(Method::POST);
 767                let request_builder =
 768                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 769                        request_builder.uri(predict_edits_url)
 770                    } else {
 771                        request_builder.uri(
 772                            http_client
 773                                .build_zed_llm_url("/predict_edits/v2", &[])?
 774                                .as_ref(),
 775                        )
 776                    };
 777                let request = request_builder
 778                    .header("Content-Type", "application/json")
 779                    .header("Authorization", format!("Bearer {}", token))
 780                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 781                    .body(serde_json::to_string(&body)?.into())?;
 782
 783                let mut response = http_client.send(request).await?;
 784
 785                if let Some(minimum_required_version) = response
 786                    .headers()
 787                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 788                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 789                {
 790                    anyhow::ensure!(
 791                        app_version >= minimum_required_version,
 792                        ZedUpdateRequiredError {
 793                            minimum_version: minimum_required_version
 794                        }
 795                    );
 796                }
 797
 798                if response.status().is_success() {
 799                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 800
 801                    let mut body = String::new();
 802                    response.body_mut().read_to_string(&mut body).await?;
 803                    return Ok((serde_json::from_str(&body)?, usage));
 804                } else if !did_retry
 805                    && response
 806                        .headers()
 807                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 808                        .is_some()
 809                {
 810                    did_retry = true;
 811                    token = llm_token.refresh(&client).await?;
 812                } else {
 813                    let mut body = String::new();
 814                    response.body_mut().read_to_string(&mut body).await?;
 815                    anyhow::bail!(
 816                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 817                        response.status(),
 818                        body
 819                    );
 820                }
 821            }
 822        }
 823    }
 824
 825    fn accept_edit_prediction(
 826        &mut self,
 827        request_id: EditPredictionId,
 828        cx: &mut Context<Self>,
 829    ) -> Task<Result<()>> {
 830        let client = self.client.clone();
 831        let llm_token = self.llm_token.clone();
 832        let app_version = AppVersion::global(cx);
 833        cx.spawn(async move |this, cx| {
 834            let http_client = client.http_client();
 835            let mut response = llm_token_retry(&llm_token, &client, |token| {
 836                let request_builder = http_client::Request::builder().method(Method::POST);
 837                let request_builder =
 838                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 839                        request_builder.uri(accept_prediction_url)
 840                    } else {
 841                        request_builder.uri(
 842                            http_client
 843                                .build_zed_llm_url("/predict_edits/accept", &[])?
 844                                .as_ref(),
 845                        )
 846                    };
 847                Ok(request_builder
 848                    .header("Content-Type", "application/json")
 849                    .header("Authorization", format!("Bearer {}", token))
 850                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 851                    .body(
 852                        serde_json::to_string(&AcceptEditPredictionBody {
 853                            request_id: request_id.0,
 854                        })?
 855                        .into(),
 856                    )?)
 857            })
 858            .await?;
 859
 860            if let Some(minimum_required_version) = response
 861                .headers()
 862                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 863                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 864                && app_version < minimum_required_version
 865            {
 866                return Err(anyhow!(ZedUpdateRequiredError {
 867                    minimum_version: minimum_required_version
 868                }));
 869            }
 870
 871            if response.status().is_success() {
 872                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 873                    this.update(cx, |this, cx| {
 874                        this.user_store.update(cx, |user_store, cx| {
 875                            user_store.update_edit_prediction_usage(usage, cx);
 876                        });
 877                    })?;
 878                }
 879
 880                Ok(())
 881            } else {
 882                let mut body = String::new();
 883                response.body_mut().read_to_string(&mut body).await?;
 884                Err(anyhow!(
 885                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 886                    response.status(),
 887                    body
 888                ))
 889            }
 890        })
 891    }
 892
    /// Converts a raw `PredictEditsResponse` into an `EditPrediction` for `buffer`.
    ///
    /// The model's `output_excerpt` is parsed into anchored edits against `snapshot`, the
    /// snapshot the request was built from (see `parse_edits`). Parsing runs on the
    /// background executor. The edits are then re-based onto the buffer's current snapshot
    /// with `interpolate`. When that yields `None`, the task resolves to `Ok(None)`.
    /// Otherwise an edit preview is computed, and the prediction is returned along with:
    /// - the request inputs (`input_outline`, `input_events`, `input_excerpt`),
    /// - the raw output,
    /// - the snapshot and response timestamps.
    ///
    /// NOTE(review): the returned `snapshot` is the newer buffer snapshot the edits were
    /// interpolated onto, while `excerpt_range` is still expressed in offsets of the
    /// request-time snapshot. Confirm consumers expect that mix.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parse and diff on the background executor, working on clones because
            // `output_excerpt`, `editable_range` and `snapshot` are still needed below.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-base the edits onto the buffer's latest contents. If `interpolate` returns
            // `None` (presumably because the buffer changed in a way that conflicts with the
            // prediction), there is nothing to offer and the task yields `Ok(None)`.
            // Note that `snapshot` is shadowed here by the newer snapshot.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 954
 955    fn parse_edits(
 956        output_excerpt: Arc<str>,
 957        editable_range: Range<usize>,
 958        snapshot: &BufferSnapshot,
 959    ) -> Result<Vec<(Range<Anchor>, String)>> {
 960        let content = output_excerpt.replace(CURSOR_MARKER, "");
 961
 962        let start_markers = content
 963            .match_indices(EDITABLE_REGION_START_MARKER)
 964            .collect::<Vec<_>>();
 965        anyhow::ensure!(
 966            start_markers.len() == 1,
 967            "expected exactly one start marker, found {}",
 968            start_markers.len()
 969        );
 970
 971        let end_markers = content
 972            .match_indices(EDITABLE_REGION_END_MARKER)
 973            .collect::<Vec<_>>();
 974        anyhow::ensure!(
 975            end_markers.len() == 1,
 976            "expected exactly one end marker, found {}",
 977            end_markers.len()
 978        );
 979
 980        let sof_markers = content
 981            .match_indices(START_OF_FILE_MARKER)
 982            .collect::<Vec<_>>();
 983        anyhow::ensure!(
 984            sof_markers.len() <= 1,
 985            "expected at most one start-of-file marker, found {}",
 986            sof_markers.len()
 987        );
 988
 989        let codefence_start = start_markers[0].0;
 990        let content = &content[codefence_start..];
 991
 992        let newline_ix = content.find('\n').context("could not find newline")?;
 993        let content = &content[newline_ix + 1..];
 994
 995        let codefence_end = content
 996            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 997            .context("could not find end marker")?;
 998        let new_text = &content[..codefence_end];
 999
1000        let old_text = snapshot
1001            .text_for_range(editable_range.clone())
1002            .collect::<String>();
1003
1004        Ok(Self::compute_edits(
1005            old_text,
1006            new_text,
1007            editable_range.start,
1008            snapshot,
1009        ))
1010    }
1011
1012    pub fn compute_edits(
1013        old_text: String,
1014        new_text: &str,
1015        offset: usize,
1016        snapshot: &BufferSnapshot,
1017    ) -> Vec<(Range<Anchor>, String)> {
1018        text_diff(&old_text, new_text)
1019            .into_iter()
1020            .map(|(mut old_range, new_text)| {
1021                old_range.start += offset;
1022                old_range.end += offset;
1023
1024                let prefix_len = common_prefix(
1025                    snapshot.chars_for_range(old_range.clone()),
1026                    new_text.chars(),
1027                );
1028                old_range.start += prefix_len;
1029
1030                let suffix_len = common_prefix(
1031                    snapshot.reversed_chars_for_range(old_range.clone()),
1032                    new_text[prefix_len..].chars().rev(),
1033                );
1034                old_range.end = old_range.end.saturating_sub(suffix_len);
1035
1036                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1037                let range = if old_range.is_empty() {
1038                    let anchor = snapshot.anchor_after(old_range.start);
1039                    anchor..anchor
1040                } else {
1041                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1042                };
1043                (range, new_text)
1044            })
1045            .collect()
1046    }
1047
1048    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1049        self.rated_completions.contains(&completion_id)
1050    }
1051
1052    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1053        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1054            let completion = self.shown_completions.pop_back().unwrap();
1055            self.rated_completions.remove(&completion.id);
1056        }
1057        self.shown_completions.push_front(completion.clone());
1058        cx.notify();
1059    }
1060
1061    pub fn rate_completion(
1062        &mut self,
1063        completion: &EditPrediction,
1064        rating: EditPredictionRating,
1065        feedback: String,
1066        cx: &mut Context<Self>,
1067    ) {
1068        self.rated_completions.insert(completion.id);
1069        telemetry::event!(
1070            "Edit Prediction Rated",
1071            rating,
1072            input_events = completion.input_events,
1073            input_excerpt = completion.input_excerpt,
1074            input_outline = completion.input_outline,
1075            output_excerpt = completion.output_excerpt,
1076            feedback
1077        );
1078        self.client.telemetry().flush_events().detach();
1079        cx.notify();
1080    }
1081
1082    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1083        self.shown_completions.iter()
1084    }
1085
1086    pub fn shown_completions_len(&self) -> usize {
1087        self.shown_completions.len()
1088    }
1089
1090    fn report_changes_for_buffer(
1091        &mut self,
1092        buffer: &Entity<Buffer>,
1093        cx: &mut Context<Self>,
1094    ) -> BufferSnapshot {
1095        self.register_buffer(buffer, cx);
1096
1097        let registered_buffer = self
1098            .registered_buffers
1099            .get_mut(&buffer.entity_id())
1100            .unwrap();
1101        let new_snapshot = buffer.read(cx).snapshot();
1102
1103        if new_snapshot.version != registered_buffer.snapshot.version {
1104            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1105            self.push_event(Event::BufferChange {
1106                old_snapshot,
1107                new_snapshot: new_snapshot.clone(),
1108                timestamp: Instant::now(),
1109            });
1110        }
1111
1112        new_snapshot
1113    }
1114
1115    fn load_data_collection_choices() -> DataCollectionChoice {
1116        let choice = KEY_VALUE_STORE
1117            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1118            .log_err()
1119            .flatten();
1120
1121        match choice.as_deref() {
1122            Some("true") => DataCollectionChoice::Enabled,
1123            Some("false") => DataCollectionChoice::Disabled,
1124            Some(_) => {
1125                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1126                DataCollectionChoice::NotAnswered
1127            }
1128            None => DataCollectionChoice::NotAnswered,
1129        }
1130    }
1131
1132    fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1133        let now = Instant::now();
1134        if let Some(existing_ix) = self
1135            .recent_project_entries
1136            .iter()
1137            .rposition(|(id, _)| *id == project_entry_id)
1138        {
1139            self.recent_project_entries.remove(existing_ix);
1140        }
1141        if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1142            self.recent_project_entries.pop_front();
1143        }
1144        self.recent_project_entries
1145            .push_back((project_entry_id, now));
1146    }
1147
    /// Lists recently active project files that belong to `repository`, most recent
    /// first, each with how many milliseconds before `now` it was last active.
    ///
    /// A file is included only if all of the following hold:
    /// - it is a created file that is not ignored, private, external or a FIFO;
    /// - its path maps into `repository` as valid UTF-8;
    /// - git does not report it as ignored or untracked.
    ///
    /// Returns an empty list if the workspace can no longer be read (presumably because it
    /// was released).
    ///
    /// Side effect: entries that can never qualify are pruned from `recent_project_entries`.
    /// These are entries whose worktree or entry is gone or no longer eligible, whose repo
    /// path is not UTF-8, or whose age does not fit in `active_to_now_ms`. Entries that
    /// belong to another repository, or that git reports as ignored or untracked, are kept.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &Context<Self>,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::new();
        // Iterate indices in reverse: newest entries are at the back, and removing index
        // `ix` never shifts the lower indices that are still to be visited.
        for ix in (0..self.recent_project_entries.len()).rev() {
            let (entry_id, last_active_at) = &self.recent_project_entries[ix];
            if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
                && let worktree = worktree.read(cx)
                && let Some(entry) = worktree.entry_for_id(*entry_id)
                && entry.is_file()
                && entry.is_created()
                && !entry.is_ignored
                && !entry.is_private
                && !entry.is_external
                && !entry.is_fifo
            {
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: entry.path.clone(),
                };
                let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
                else {
                    // Keep the entry: a later query for a different repository may match it.
                    continue;
                };
                let Some(repo_path_str) = repo_path.to_str() else {
                    // The repo path is not valid UTF-8, so it can never be reported; drop it.
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                if let Some(file_status) = repository.status_for_path(&repo_path) {
                    if file_status.is_ignored() || file_status.is_untracked() {
                        // Keep the entry: it may belong to a nested repository.
                        continue;
                    }
                }
                // Drop entries whose age in milliseconds does not fit the field's type.
                let Ok(active_to_now_ms) =
                    now.duration_since(*last_active_at).as_millis().try_into()
                else {
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                results.push(PredictEditsRecentFile {
                    repo_path: repo_path_str.to_string(),
                    active_to_now_ms,
                });
            } else {
                // The worktree or entry is gone, or the entry is not an eligible file.
                self.recent_project_entries.remove(ix);
            }
        }
        results
    }
1209}
1210
/// Bundle of inputs for a predict-edits request.
pub struct PerformPredictEditsParams {
    /// Client used to reach the server; it also provides the HTTP client.
    pub client: Arc<Client>,
    /// LLM API token, presumably used for the `Authorization` header. TODO confirm.
    pub llm_token: LlmApiToken,
    /// Version of the running app.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1217
/// Error returned when the server's `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header
/// names a Zed version newer than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version the server will still serve.
    minimum_version: SemanticVersion,
}
1225
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two character
/// sequences.
///
/// The length is measured on `a`'s characters, which equal `b`'s over the shared prefix.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1232
1233fn git_repository_for_file(
1234    project: &Entity<Project>,
1235    project_path: &ProjectPath,
1236    cx: &App,
1237) -> Option<Entity<Repository>> {
1238    let git_store = project.read(cx).git_store().read(cx);
1239    git_store
1240        .repository_and_path_for_project_path(project_path, cx)
1241        .map(|(repo, _repo_path)| repo)
1242}
1243
1244fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
1245    let head_sha = repository
1246        .head_commit
1247        .as_ref()
1248        .map(|head_commit| head_commit.sha.to_string());
1249    let remote_origin_url = repository.remote_origin_url.clone();
1250    let remote_upstream_url = repository.remote_upstream_url.clone();
1251    if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1252        return None;
1253    }
1254    Some(PredictEditsGitInfo {
1255        head_sha,
1256        remote_origin_url,
1257        remote_upstream_url,
1258    })
1259}
1260
/// Result of [`gather_context`]: the request body plus the editable range it describes.
pub struct GatherContextOutput {
    /// Request payload for the predict-edits call.
    pub body: PredictEditsBody,
    /// Offset range of the editable region within the snapshot the context was
    /// gathered from.
    pub editable_range: Range<usize>,
}
1265
1266pub fn gather_context(
1267    project: Option<&Entity<Project>>,
1268    full_path_str: String,
1269    snapshot: &BufferSnapshot,
1270    cursor_point: language::Point,
1271    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1272    can_collect_data: bool,
1273    git_info: Option<PredictEditsGitInfo>,
1274    recent_files: Option<Vec<PredictEditsRecentFile>>,
1275    cx: &App,
1276) -> Task<Result<GatherContextOutput>> {
1277    let local_lsp_store =
1278        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1279    let diagnostic_groups: Vec<(String, serde_json::Value)> =
1280        if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
1281            snapshot
1282                .diagnostic_groups(None)
1283                .into_iter()
1284                .filter_map(|(language_server_id, diagnostic_group)| {
1285                    let language_server =
1286                        local_lsp_store.running_language_server_for_id(language_server_id)?;
1287                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1288                    let language_server_name = language_server.name().to_string();
1289                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
1290                    Some((language_server_name, serialized))
1291                })
1292                .collect::<Vec<_>>()
1293        } else {
1294            Vec::new()
1295        };
1296
1297    cx.background_spawn({
1298        let snapshot = snapshot.clone();
1299        async move {
1300            let diagnostic_groups = if diagnostic_groups.is_empty()
1301                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1302            {
1303                None
1304            } else {
1305                Some(diagnostic_groups)
1306            };
1307
1308            let input_excerpt = excerpt_for_cursor_position(
1309                cursor_point,
1310                &full_path_str,
1311                &snapshot,
1312                MAX_REWRITE_TOKENS,
1313                MAX_CONTEXT_TOKENS,
1314            );
1315            let input_events = make_events_prompt();
1316            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1317
1318            let body = PredictEditsBody {
1319                input_events,
1320                input_excerpt: input_excerpt.prompt,
1321                can_collect_data,
1322                diagnostic_groups,
1323                git_info,
1324                outline: None,
1325                speculated_output: None,
1326                recent_files,
1327            };
1328
1329            Ok(GatherContextOutput {
1330                body,
1331                editable_range,
1332            })
1333        }
1334    })
1335}
1336
1337fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1338    let mut result = String::new();
1339    for event in events.iter().rev() {
1340        let event_string = event.to_prompt();
1341        let event_tokens = tokens_for_bytes(event_string.len());
1342        if event_tokens > remaining_tokens {
1343            break;
1344        }
1345
1346        if !result.is_empty() {
1347            result.insert_str(0, "\n\n");
1348        }
1349        result.insert_str(0, &event_string);
1350        remaining_tokens -= event_tokens;
1351    }
1352    result
1353}
1354
/// Per-buffer bookkeeping for buffers registered with `Zeta`.
struct RegisteredBuffer {
    /// Last snapshot recorded for this buffer. It is compared against the live snapshot
    /// to detect edits.
    snapshot: BufferSnapshot,
    /// Never read; held only so the subscriptions stay alive. Presumably they last as
    /// long as the registration does; confirm in `register_buffer`.
    _subscriptions: [gpui::Subscription; 2],
}
1359
/// A recorded user action that can be rendered into the prediction prompt.
#[derive(Clone)]
pub enum Event {
    /// The buffer went from `old_snapshot` to `new_snapshot`; recorded at `timestamp`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
1368
1369impl Event {
1370    fn to_prompt(&self) -> String {
1371        match self {
1372            Event::BufferChange {
1373                old_snapshot,
1374                new_snapshot,
1375                ..
1376            } => {
1377                let mut prompt = String::new();
1378
1379                let old_path = old_snapshot
1380                    .file()
1381                    .map(|f| f.path().as_ref())
1382                    .unwrap_or(Path::new("untitled"));
1383                let new_path = new_snapshot
1384                    .file()
1385                    .map(|f| f.path().as_ref())
1386                    .unwrap_or(Path::new("untitled"));
1387                if old_path != new_path {
1388                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1389                }
1390
1391                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1392                if !diff.is_empty() {
1393                    write!(
1394                        prompt,
1395                        "User edited {:?}:\n```diff\n{}\n```",
1396                        new_path, diff
1397                    )
1398                    .unwrap();
1399                }
1400
1401                prompt
1402            }
1403        }
1404    }
1405}
1406
/// The prediction currently offered by the provider, tagged with the buffer it targets.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction applies to.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1412
1413impl CurrentEditPrediction {
1414    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1415        if self.buffer_id != old_completion.buffer_id {
1416            return true;
1417        }
1418
1419        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1420            return true;
1421        };
1422        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1423            return false;
1424        };
1425
1426        if old_edits.len() == 1 && new_edits.len() == 1 {
1427            let (old_range, old_text) = &old_edits[0];
1428            let (new_range, new_text) = &new_edits[0];
1429            new_range == old_range && new_text.starts_with(old_text)
1430        } else {
1431            true
1432        }
1433    }
1434}
1435
/// An in-flight prediction request.
struct PendingCompletion {
    /// Request id, presumably taken from `next_pending_completion_id`. Confirm in `refresh`.
    id: usize,
    /// Task driving the request. It is never read and is held only to keep the task
    /// owned (presumably dropping it cancels the request).
    _task: Task<()>,
}
1440
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been made yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Whether collection is allowed. Only an explicit opt-in counts.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made any choice at all.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice. An unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::NotAnswered | Self::Disabled => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` maps to `Enabled`; `false` maps to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1480
/// Data-collection state for the buffer an edit-prediction provider serves.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License watcher for the buffer's worktree. `ProviderDataCollection::new` sets it
    /// together with `choice`: both are `Some` or both are `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1486
1487impl ProviderDataCollection {
1488    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1489        let choice_and_watcher = buffer.and_then(|buffer| {
1490            let file = buffer.read(cx).file()?;
1491
1492            if !file.is_local() || file.is_private() {
1493                return None;
1494            }
1495
1496            let zeta = zeta.read(cx);
1497            let choice = zeta.data_collection_choice.clone();
1498
1499            let license_detection_watcher = zeta
1500                .license_detection_watchers
1501                .get(&file.worktree_id(cx))
1502                .cloned()?;
1503
1504            Some((choice, license_detection_watcher))
1505        });
1506
1507        if let Some((choice, watcher)) = choice_and_watcher {
1508            ProviderDataCollection {
1509                choice: Some(choice),
1510                license_detection_watcher: Some(watcher),
1511            }
1512        } else {
1513            ProviderDataCollection {
1514                choice: None,
1515                license_detection_watcher: None,
1516            }
1517        }
1518    }
1519
1520    pub fn can_collect_data(&self, cx: &App) -> bool {
1521        self.is_data_collection_enabled(cx) && self.is_project_open_source()
1522    }
1523
1524    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1525        self.choice
1526            .as_ref()
1527            .is_some_and(|choice| choice.read(cx).is_enabled())
1528    }
1529
1530    fn is_project_open_source(&self) -> bool {
1531        self.license_detection_watcher
1532            .as_ref()
1533            .is_some_and(|watcher| watcher.is_project_open_source())
1534    }
1535
1536    pub fn toggle(&mut self, cx: &mut App) {
1537        if let Some(choice) = self.choice.as_mut() {
1538            let new_choice = choice.update(cx, |choice, _cx| {
1539                let new_choice = choice.toggle();
1540                *choice = new_choice;
1541                new_choice
1542            });
1543
1544            db::write_and_log(cx, move || {
1545                KEY_VALUE_STORE.write_kvp(
1546                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1547                    new_choice.is_enabled().to_string(),
1548                )
1549            });
1550        }
1551    }
1552}
1553
1554async fn llm_token_retry(
1555    llm_token: &LlmApiToken,
1556    client: &Arc<Client>,
1557    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1558) -> Result<Response<AsyncBody>> {
1559    let mut did_retry = false;
1560    let http_client = client.http_client();
1561    let mut token = llm_token.acquire(client).await?;
1562    loop {
1563        let request = build_request(token.clone())?;
1564        let response = http_client.send(request).await?;
1565
1566        if !did_retry
1567            && !response.status().is_success()
1568            && response
1569                .headers()
1570                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1571                .is_some()
1572        {
1573            did_retry = true;
1574            token = llm_token.refresh(client).await?;
1575            continue;
1576        }
1577
1578        return Ok(response);
1579    }
1580}
1581
/// Edit-prediction provider that requests predictions from a shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests; the `ArrayVec` holds at most two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request that is started.
    next_pending_completion_id: usize,
    /// Prediction currently being offered, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider's buffer. Collection is impossible when
    /// its inner choice is `None`; this field itself is never absent.
    provider_data_collection: ProviderDataCollection,
    /// When the last request was issued. New requests wait until `THROTTLE_TIMEOUT` has
    /// passed since then.
    last_request_timestamp: Instant,
}
1591
1592impl ZetaEditPredictionProvider {
1593    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1594
1595    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1596        Self {
1597            zeta,
1598            pending_completions: ArrayVec::new(),
1599            next_pending_completion_id: 0,
1600            current_completion: None,
1601            provider_data_collection,
1602            last_request_timestamp: Instant::now(),
1603        }
1604    }
1605}
1606
1607impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1608    fn name() -> &'static str {
1609        "zed-predict"
1610    }
1611
1612    fn display_name() -> &'static str {
1613        "Zed's Edit Predictions"
1614    }
1615
1616    fn show_completions_in_menu() -> bool {
1617        true
1618    }
1619
1620    fn show_tab_accept_marker() -> bool {
1621        true
1622    }
1623
1624    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1625        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1626
1627        if self.provider_data_collection.is_data_collection_enabled(cx) {
1628            DataCollectionState::Enabled {
1629                is_project_open_source,
1630            }
1631        } else {
1632            DataCollectionState::Disabled {
1633                is_project_open_source,
1634            }
1635        }
1636    }
1637
1638    fn toggle_data_collection(&mut self, cx: &mut App) {
1639        self.provider_data_collection.toggle(cx);
1640    }
1641
1642    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1643        self.zeta.read(cx).usage(cx)
1644    }
1645
1646    fn is_enabled(
1647        &self,
1648        _buffer: &Entity<Buffer>,
1649        _cursor_position: language::Anchor,
1650        _cx: &App,
1651    ) -> bool {
1652        true
1653    }
1654    fn is_refreshing(&self) -> bool {
1655        !self.pending_completions.is_empty()
1656    }
1657
1658    fn refresh(
1659        &mut self,
1660        project: Option<Entity<Project>>,
1661        buffer: Entity<Buffer>,
1662        position: language::Anchor,
1663        _debounce: bool,
1664        cx: &mut Context<Self>,
1665    ) {
1666        if self.zeta.read(cx).update_required {
1667            return;
1668        }
1669
1670        if self
1671            .zeta
1672            .read(cx)
1673            .user_store
1674            .read_with(cx, |user_store, _cx| {
1675                user_store.account_too_young() || user_store.has_overdue_invoices()
1676            })
1677        {
1678            return;
1679        }
1680
1681        if let Some(current_completion) = self.current_completion.as_ref() {
1682            let snapshot = buffer.read(cx).snapshot();
1683            if current_completion
1684                .completion
1685                .interpolate(&snapshot)
1686                .is_some()
1687            {
1688                return;
1689            }
1690        }
1691
1692        let pending_completion_id = self.next_pending_completion_id;
1693        self.next_pending_completion_id += 1;
1694        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1695        let last_request_timestamp = self.last_request_timestamp;
1696
1697        let task = cx.spawn(async move |this, cx| {
1698            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1699                .checked_duration_since(Instant::now())
1700            {
1701                cx.background_executor().timer(timeout).await;
1702            }
1703
1704            let completion_request = this.update(cx, |this, cx| {
1705                this.last_request_timestamp = Instant::now();
1706                this.zeta.update(cx, |zeta, cx| {
1707                    zeta.request_completion(
1708                        project.as_ref(),
1709                        &buffer,
1710                        position,
1711                        can_collect_data,
1712                        cx,
1713                    )
1714                })
1715            });
1716
1717            let completion = match completion_request {
1718                Ok(completion_request) => {
1719                    let completion_request = completion_request.await;
1720                    completion_request.map(|c| {
1721                        c.map(|completion| CurrentEditPrediction {
1722                            buffer_id: buffer.entity_id(),
1723                            completion,
1724                        })
1725                    })
1726                }
1727                Err(error) => Err(error),
1728            };
1729            let Some(new_completion) = completion
1730                .context("edit prediction failed")
1731                .log_err()
1732                .flatten()
1733            else {
1734                this.update(cx, |this, cx| {
1735                    if this.pending_completions[0].id == pending_completion_id {
1736                        this.pending_completions.remove(0);
1737                    } else {
1738                        this.pending_completions.clear();
1739                    }
1740
1741                    cx.notify();
1742                })
1743                .ok();
1744                return;
1745            };
1746
1747            this.update(cx, |this, cx| {
1748                if this.pending_completions[0].id == pending_completion_id {
1749                    this.pending_completions.remove(0);
1750                } else {
1751                    this.pending_completions.clear();
1752                }
1753
1754                if let Some(old_completion) = this.current_completion.as_ref() {
1755                    let snapshot = buffer.read(cx).snapshot();
1756                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1757                        this.zeta.update(cx, |zeta, cx| {
1758                            zeta.completion_shown(&new_completion.completion, cx);
1759                        });
1760                        this.current_completion = Some(new_completion);
1761                    }
1762                } else {
1763                    this.zeta.update(cx, |zeta, cx| {
1764                        zeta.completion_shown(&new_completion.completion, cx);
1765                    });
1766                    this.current_completion = Some(new_completion);
1767                }
1768
1769                cx.notify();
1770            })
1771            .ok();
1772        });
1773
1774        // We always maintain at most two pending completions. When we already
1775        // have two, we replace the newest one.
1776        if self.pending_completions.len() <= 1 {
1777            self.pending_completions.push(PendingCompletion {
1778                id: pending_completion_id,
1779                _task: task,
1780            });
1781        } else if self.pending_completions.len() == 2 {
1782            self.pending_completions.pop();
1783            self.pending_completions.push(PendingCompletion {
1784                id: pending_completion_id,
1785                _task: task,
1786            });
1787        }
1788    }
1789
1790    fn cycle(
1791        &mut self,
1792        _buffer: Entity<Buffer>,
1793        _cursor_position: language::Anchor,
1794        _direction: edit_prediction::Direction,
1795        _cx: &mut Context<Self>,
1796    ) {
1797        // Right now we don't support cycling.
1798    }
1799
1800    fn accept(&mut self, cx: &mut Context<Self>) {
1801        let completion_id = self
1802            .current_completion
1803            .as_ref()
1804            .map(|completion| completion.completion.id);
1805        if let Some(completion_id) = completion_id {
1806            self.zeta
1807                .update(cx, |zeta, cx| {
1808                    zeta.accept_edit_prediction(completion_id, cx)
1809                })
1810                .detach();
1811        }
1812        self.pending_completions.clear();
1813    }
1814
1815    fn discard(&mut self, _cx: &mut Context<Self>) {
1816        self.pending_completions.clear();
1817        self.current_completion.take();
1818    }
1819
1820    fn suggest(
1821        &mut self,
1822        buffer: &Entity<Buffer>,
1823        cursor_position: language::Anchor,
1824        cx: &mut Context<Self>,
1825    ) -> Option<edit_prediction::EditPrediction> {
1826        let CurrentEditPrediction {
1827            buffer_id,
1828            completion,
1829            ..
1830        } = self.current_completion.as_mut()?;
1831
1832        // Invalidate previous completion if it was generated for a different buffer.
1833        if *buffer_id != buffer.entity_id() {
1834            self.current_completion.take();
1835            return None;
1836        }
1837
1838        let buffer = buffer.read(cx);
1839        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1840            self.current_completion.take();
1841            return None;
1842        };
1843
1844        let cursor_row = cursor_position.to_point(buffer).row;
1845        let (closest_edit_ix, (closest_edit_range, _)) =
1846            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1847                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1848                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1849                cmp::min(distance_from_start, distance_from_end)
1850            })?;
1851
1852        let mut edit_start_ix = closest_edit_ix;
1853        for (range, _) in edits[..edit_start_ix].iter().rev() {
1854            let distance_from_closest_edit =
1855                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1856            if distance_from_closest_edit <= 1 {
1857                edit_start_ix -= 1;
1858            } else {
1859                break;
1860            }
1861        }
1862
1863        let mut edit_end_ix = closest_edit_ix + 1;
1864        for (range, _) in &edits[edit_end_ix..] {
1865            let distance_from_closest_edit =
1866                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1867            if distance_from_closest_edit <= 1 {
1868                edit_end_ix += 1;
1869            } else {
1870                break;
1871            }
1872        }
1873
1874        Some(edit_prediction::EditPrediction {
1875            id: Some(completion.id.to_string().into()),
1876            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1877            edit_preview: Some(completion.edit_preview.clone()),
1878        })
1879    }
1880}
1881
/// Roughly estimates how many model tokens `bytes` bytes of text occupy.
///
/// The bytes-per-token ratio is deliberately small, so the estimate leans
/// toward more tokens than reality and any input budget derived from it stays
/// conservative. Fractional tokens are truncated.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed string bytes per token. Kept intentionally low so that limits
    /// computed from it are underestimated rather than overestimated.
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    bytes / BYTES_PER_TOKEN_GUESS
}
1888
// Tests for prediction interpolation, diff cleanup of model output, and full
// request round-trips against a fake HTTP backend.
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    // Checks that a prediction's edits are re-mapped as the buffer changes:
    // typing part of a predicted insertion shrinks it, performing a predicted
    // edit removes it, undo restores the original edits, and a conflicting edit
    // makes interpolation return `None`.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        // Only `edits`, `edit_preview`, `snapshot` and `id` matter here; the
        // remaining fields are placeholders.
        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Unchanged buffer: edits come back as originally predicted.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the replaced text turns the replacement into an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing a prefix of the predicted text shrinks the insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Completing the insertion leaves only the predicted deletion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Removing the typed "M" brings the insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion removes it from the result.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that conflicts with the remaining prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    // Checks that model output is reduced to minimal insertions rather than
    // whole-line replacements.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    // Checks a prediction that appends a line at the end of a buffer whose
    // last line ends in a newline.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake backend: hands out an LLM token and serves the canned prediction;
        // any other route answers 404.
        let http_client = FakeHttpClient::create(move |req| async move {
            match (req.method(), req.uri().path()) {
                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&CreateLlmTokenResponse {
                            token: LlmToken("the-llm-token".to_string()),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                                .unwrap(),
                            output_excerpt: completion_response.to_string(),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                _ => Ok(http_client::Response::builder()
                    .status(404)
                    .body("Not Found".into())
                    .unwrap()),
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Requests one prediction for a buffer containing `buffer_content` (cursor
    /// at row 1, column 0) from a fake backend that answers with
    /// `completion_response` as the model output, and returns the resulting
    /// edits as point ranges resolved against the buffer's initial snapshot.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        // Fake backend: hands out an LLM token and serves the canned prediction;
        // any other route answers 404.
        let http_client = FakeHttpClient::create(move |req| {
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`; range
    /// starts use `anchor_after` and range ends use `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to byte offsets in `buffer`'s current
    /// contents (the inverse of `to_completion_edits`).
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    // Initializes test logging once, when the test binary starts (via `ctor`).
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}