zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9use arrayvec::ArrayVec;
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction::DataCollectionState;
  13pub use init::*;
  14use license_detection::LicenseDetectionWatcher;
  15use project::git_store::Repository;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use client::{Client, EditPredictionUsage, UserStore};
  20use cloud_llm_client::{
  21    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  22    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
  23    ZED_VERSION_HEADER_NAME,
  24};
  25use collections::{HashMap, HashSet, VecDeque};
  26use futures::AsyncReadExt;
  27use gpui::{
  28    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  29    Subscription, Task, WeakEntity, actions,
  30};
  31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  32use input_excerpt::excerpt_for_cursor_position;
  33use language::{
  34    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  35};
  36use language_model::{LlmApiToken, RefreshLlmTokenListener};
  37use project::{Project, ProjectEntryId, ProjectPath};
  38use release_channel::AppVersion;
  39use settings::WorktreeId;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    mem,
  46    ops::Range,
  47    path::Path,
  48    rc::Rc,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::EditPredictionRating;
  53use thiserror::Error;
  54use util::ResultExt;
  55use uuid::Uuid;
  56use workspace::Workspace;
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  58use worktree::Worktree;
  59
  60const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
  61const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
  62const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
  63const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
  64const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  65const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  66
  67const MAX_CONTEXT_TOKENS: usize = 150;
  68const MAX_REWRITE_TOKENS: usize = 350;
  69const MAX_EVENT_TOKENS: usize = 500;
  70const MAX_DIAGNOSTIC_GROUPS: usize = 10;
  71
  72/// Maximum number of events to track.
  73const MAX_EVENT_COUNT: usize = 16;
  74
  75/// Maximum number of recent files to track.
  76const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
  77
  78/// Maximum file path length to include in recent files list.
  79const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;
  80
  81/// Maximum number of edit predictions to store for feedback.
  82const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
  83
// Declares the `edit_prediction::ClearHistory` action. The history it clears is presumably
// `Zeta::events` (see `Zeta::clear_history`); the handler binding is not in this file section
// — TODO confirm.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  91
  92#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  93pub struct EditPredictionId(Uuid);
  94
  95impl From<EditPredictionId> for gpui::ElementId {
  96    fn from(value: EditPredictionId) -> Self {
  97        gpui::ElementId::Uuid(value.0)
  98    }
  99}
 100
 101impl std::fmt::Display for EditPredictionId {
 102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 103        write!(f, "{}", self.0)
 104    }
 105}
 106
 107struct ZedPredictUpsell;
 108
 109impl Dismissable for ZedPredictUpsell {
 110    const KEY: &'static str = "dismissed-edit-predict-upsell";
 111
 112    fn dismissed() -> bool {
 113        // To make this backwards compatible with older versions of Zed, we
 114        // check if the user has seen the previous Edit Prediction Onboarding
 115        // before, by checking the data collection choice which was written to
 116        // the database once the user clicked on "Accept and Enable"
 117        if KEY_VALUE_STORE
 118            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 119            .log_err()
 120            .is_some_and(|s| s.is_some())
 121        {
 122            return true;
 123        }
 124
 125        KEY_VALUE_STORE
 126            .read_kvp(Self::KEY)
 127            .log_err()
 128            .is_some_and(|s| s.is_some())
 129    }
 130}
 131
 132pub fn should_show_upsell_modal() -> bool {
 133    !ZedPredictUpsell::dismissed()
 134}
 135
/// App-global handle to the single shared [`Zeta`] entity; installed by `Zeta::register`
/// and read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 140
 141#[derive(Clone)]
 142pub struct EditPrediction {
 143    id: EditPredictionId,
 144    path: Arc<Path>,
 145    excerpt_range: Range<usize>,
 146    cursor_offset: usize,
 147    edits: Arc<[(Range<Anchor>, String)]>,
 148    snapshot: BufferSnapshot,
 149    edit_preview: EditPreview,
 150    input_outline: Arc<str>,
 151    input_events: Arc<str>,
 152    input_excerpt: Arc<str>,
 153    output_excerpt: Arc<str>,
 154    buffer_snapshotted_at: Instant,
 155    response_received_at: Instant,
 156}
 157
 158impl EditPrediction {
 159    fn latency(&self) -> Duration {
 160        self.response_received_at
 161            .duration_since(self.buffer_snapshotted_at)
 162    }
 163
 164    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 165        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 166    }
 167}
 168
 169fn interpolate(
 170    old_snapshot: &BufferSnapshot,
 171    new_snapshot: &BufferSnapshot,
 172    current_edits: Arc<[(Range<Anchor>, String)]>,
 173) -> Option<Vec<(Range<Anchor>, String)>> {
 174    let mut edits = Vec::new();
 175
 176    let mut model_edits = current_edits.iter().peekable();
 177    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 178        while let Some((model_old_range, _)) = model_edits.peek() {
 179            let model_old_range = model_old_range.to_offset(old_snapshot);
 180            if model_old_range.end < user_edit.old.start {
 181                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 182                edits.push((model_old_range.clone(), model_new_text.clone()));
 183            } else {
 184                break;
 185            }
 186        }
 187
 188        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 189            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 190            if user_edit.old == model_old_offset_range {
 191                let user_new_text = new_snapshot
 192                    .text_for_range(user_edit.new.clone())
 193                    .collect::<String>();
 194
 195                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 196                    if !model_suffix.is_empty() {
 197                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 198                        edits.push((anchor..anchor, model_suffix.to_string()));
 199                    }
 200
 201                    model_edits.next();
 202                    continue;
 203                }
 204            }
 205        }
 206
 207        return None;
 208    }
 209
 210    edits.extend(model_edits.cloned());
 211
 212    if edits.is_empty() { None } else { Some(edits) }
 213}
 214
 215impl std::fmt::Debug for EditPrediction {
 216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 217        f.debug_struct("EditPrediction")
 218            .field("id", &self.id)
 219            .field("path", &self.path)
 220            .field("edits", &self.edits)
 221            .finish_non_exhaustive()
 222    }
 223}
 224
/// Edit-prediction state shared across the app. A single instance is created and installed
/// as a global by `Zeta::register`.
pub struct Zeta {
    /// Workspace used to surface notifications; an invalid handle when none was given to `new`.
    workspace: WeakEntity<Workspace>,
    /// Client used for prediction requests and LLM token refreshes.
    client: Arc<Client>,
    /// Recent buffer-change events, cloned into each prediction prompt; bounded by `push_event`.
    events: VecDeque<Event>,
    /// Tracking state per registered buffer, keyed by entity id; removed when the buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions that were shown, kept for feedback (presumably capped at
    /// `MAX_SHOWN_COMPLETION_COUNT`; the insertion code is not in view — TODO confirm).
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions the user has rated (presumably — TODO confirm).
    rated_completions: HashSet<EditPredictionId>,
    /// The user's data collection choice, loaded when `Zeta` is created.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Receives the edit-prediction usage parsed from server responses.
    user_store: Entity<UserStore>,
    /// One license detection watcher per worktree passed to `Zeta::register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently active project entries with timestamps (maintained by
    /// `push_recent_project_entry`, which is not in view — TODO confirm semantics).
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
 241
 242impl Zeta {
 243    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 244        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 245    }
 246
 247    pub fn register(
 248        workspace: Option<Entity<Workspace>>,
 249        worktree: Option<Entity<Worktree>>,
 250        client: Arc<Client>,
 251        user_store: Entity<UserStore>,
 252        cx: &mut App,
 253    ) -> Entity<Self> {
 254        let this = Self::global(cx).unwrap_or_else(|| {
 255            let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
 256            cx.set_global(ZetaGlobal(entity.clone()));
 257            entity
 258        });
 259
 260        this.update(cx, move |this, cx| {
 261            if let Some(worktree) = worktree {
 262                let worktree_id = worktree.read(cx).id();
 263                this.license_detection_watchers
 264                    .entry(worktree_id)
 265                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 266            }
 267        });
 268
 269        this
 270    }
 271
 272    pub fn clear_history(&mut self) {
 273        self.events.clear();
 274    }
 275
 276    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 277        self.user_store.read(cx).edit_prediction_usage()
 278    }
 279
 280    fn new(
 281        workspace: Option<Entity<Workspace>>,
 282        client: Arc<Client>,
 283        user_store: Entity<UserStore>,
 284        cx: &mut Context<Self>,
 285    ) -> Self {
 286        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 287
 288        let data_collection_choice = Self::load_data_collection_choices();
 289        let data_collection_choice = cx.new(|_| data_collection_choice);
 290
 291        if let Some(workspace) = &workspace {
 292            cx.subscribe(
 293                &workspace.read(cx).project().clone(),
 294                |this, _workspace, event, _cx| match event {
 295                    project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
 296                        this.push_recent_project_entry(*project_entry_id)
 297                    }
 298                    _ => {}
 299                },
 300            )
 301            .detach();
 302        }
 303
 304        Self {
 305            workspace: workspace.map_or_else(
 306                || WeakEntity::new_invalid(),
 307                |workspace| workspace.downgrade(),
 308            ),
 309            client,
 310            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 311            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
 312            rated_completions: HashSet::default(),
 313            registered_buffers: HashMap::default(),
 314            data_collection_choice,
 315            llm_token: LlmApiToken::default(),
 316            _llm_token_subscription: cx.subscribe(
 317                &refresh_llm_token_listener,
 318                |this, _listener, _event, cx| {
 319                    let client = this.client.clone();
 320                    let llm_token = this.llm_token.clone();
 321                    cx.spawn(async move |_this, _cx| {
 322                        llm_token.refresh(&client).await?;
 323                        anyhow::Ok(())
 324                    })
 325                    .detach_and_log_err(cx);
 326                },
 327            ),
 328            update_required: false,
 329            license_detection_watchers: HashMap::default(),
 330            user_store,
 331            recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
 332        }
 333    }
 334
 335    fn push_event(&mut self, event: Event) {
 336        if let Some(Event::BufferChange {
 337            new_snapshot: last_new_snapshot,
 338            timestamp: last_timestamp,
 339            ..
 340        }) = self.events.back_mut()
 341        {
 342            // Coalesce edits for the same buffer when they happen one after the other.
 343            let Event::BufferChange {
 344                old_snapshot,
 345                new_snapshot,
 346                timestamp,
 347            } = &event;
 348
 349            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 350                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 351                && old_snapshot.version == last_new_snapshot.version
 352            {
 353                *last_new_snapshot = new_snapshot.clone();
 354                *last_timestamp = *timestamp;
 355                return;
 356            }
 357        }
 358
 359        if self.events.len() >= MAX_EVENT_COUNT {
 360            // These are halved instead of popping to improve prompt caching.
 361            self.events.drain(..MAX_EVENT_COUNT / 2);
 362        }
 363
 364        self.events.push_back(event);
 365    }
 366
 367    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 368        let buffer_id = buffer.entity_id();
 369        let weak_buffer = buffer.downgrade();
 370
 371        if let std::collections::hash_map::Entry::Vacant(entry) =
 372            self.registered_buffers.entry(buffer_id)
 373        {
 374            let snapshot = buffer.read(cx).snapshot();
 375
 376            entry.insert(RegisteredBuffer {
 377                snapshot,
 378                _subscriptions: [
 379                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 380                        this.handle_buffer_event(buffer, event, cx);
 381                    }),
 382                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 383                        this.registered_buffers.remove(&weak_buffer.entity_id());
 384                    }),
 385                ],
 386            });
 387        };
 388    }
 389
 390    fn handle_buffer_event(
 391        &mut self,
 392        buffer: Entity<Buffer>,
 393        event: &language::BufferEvent,
 394        cx: &mut Context<Self>,
 395    ) {
 396        if let language::BufferEvent::Edited = event {
 397            self.report_changes_for_buffer(&buffer, cx);
 398        }
 399    }
 400
    /// Shared implementation of [`Zeta::request_completion`] and the test-only `fake_completion`.
    ///
    /// Snapshots `buffer`, gathers prompt context around `cursor`, and performs the request
    /// through `perform_predict_edits`. The response is then turned into an [`EditPrediction`]
    /// via `process_completion_response`.
    ///
    /// - `workspace`: when present, used to show an "Update Zed" notification if the request
    ///   fails with [`ZedUpdateRequiredError`].
    /// - `project`: passed to context gathering. When `can_collect_data` is true it is also
    ///   used to look up git info and recent files for the buffer's file.
    /// - `perform_predict_edits`: performs the request; injectable so tests can return canned
    ///   responses.
    ///
    /// The returned task fails if context gathering fails or `perform_predict_edits` errors.
    /// A [`ZedUpdateRequiredError`] additionally sets `update_required` before being returned.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency measurements reported in telemetry below.
        let buffer_snapshotted_at = Instant::now();
        // Produces the snapshot used for this request. It runs before `self.events` is cloned,
        // presumably so any change it reports is part of this prompt's history — TODO confirm.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Git info and recent files are collected only when data collection is allowed, a
        // project and file exist, and the file belongs to a git repository.
        let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
            (can_collect_data, project, snapshot.file())
            && let Some(repository) =
                git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        {
            let repository = repository.read(cx);
            let git_info = make_predict_edits_git_info(repository);
            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
            (git_info, Some(recent_files))
        } else {
            (None, None)
        };

        // Buffers without a file are reported under the placeholder path "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Deferred so the events text is rendered by the context-gathering task, within
        // the MAX_EVENT_TOKENS budget.
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            recent_files,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Copies of the prompt inputs are kept because `body` is moved into the request.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // A version rejection flips `update_required` and prompts the user to
                    // update. The error is still returned to the caller either way.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage from the response headers (when present) is stored in the user store;
            // failure because Zeta was dropped is ignored.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // Record latency for a small sample of requests: `<= 2` is true for 3 of the 256
            // possible `u8` values (~1.2%).
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 563
 564    // Generates several example completions of various states to fill the Zeta completion modal
 565    #[cfg(any(test, feature = "test-support"))]
 566    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
 567        use language::Point;
 568
 569        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 570            And maybe a short line
 571
 572            Then a few lines
 573
 574            and then another
 575            "#};
 576
 577        let project = None;
 578        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
 579        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 580
 581        let completion_tasks = vec![
 582            self.fake_completion(
 583                project,
 584                &buffer,
 585                position,
 586                PredictEditsResponse {
 587                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
 588                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 589a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 590[here's an edit]
 591And maybe a short line
 592Then a few lines
 593and then another
 594{EDITABLE_REGION_END_MARKER}
 595                        ", ),
 596                },
 597                cx,
 598            ),
 599            self.fake_completion(
 600                project,
 601                &buffer,
 602                position,
 603                PredictEditsResponse {
 604                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
 605                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 606a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 607And maybe a short line
 608[and another edit]
 609Then a few lines
 610and then another
 611{EDITABLE_REGION_END_MARKER}
 612                        "#),
 613                },
 614                cx,
 615            ),
 616            self.fake_completion(
 617                project,
 618                &buffer,
 619                position,
 620                PredictEditsResponse {
 621                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
 622                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 623a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 624And maybe a short line
 625
 626Then a few lines
 627
 628and then another
 629{EDITABLE_REGION_END_MARKER}
 630                        "#),
 631                },
 632                cx,
 633            ),
 634            self.fake_completion(
 635                project,
 636                &buffer,
 637                position,
 638                PredictEditsResponse {
 639                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
 640                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 641a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 642And maybe a short line
 643
 644Then a few lines
 645
 646and then another
 647{EDITABLE_REGION_END_MARKER}
 648                        "#),
 649                },
 650                cx,
 651            ),
 652            self.fake_completion(
 653                project,
 654                &buffer,
 655                position,
 656                PredictEditsResponse {
 657                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
 658                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 659a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 660And maybe a short line
 661Then a few lines
 662[a third completion]
 663and then another
 664{EDITABLE_REGION_END_MARKER}
 665                        "#),
 666                },
 667                cx,
 668            ),
 669            self.fake_completion(
 670                project,
 671                &buffer,
 672                position,
 673                PredictEditsResponse {
 674                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
 675                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 676a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 677And maybe a short line
 678and then another
 679[fourth completion example]
 680{EDITABLE_REGION_END_MARKER}
 681                        "#),
 682                },
 683                cx,
 684            ),
 685            self.fake_completion(
 686                project,
 687                &buffer,
 688                position,
 689                PredictEditsResponse {
 690                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
 691                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 692a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 693And maybe a short line
 694Then a few lines
 695and then another
 696[fifth and final completion]
 697{EDITABLE_REGION_END_MARKER}
 698                        "#),
 699                },
 700                cx,
 701            ),
 702        ];
 703
 704        cx.spawn(async move |zeta, cx| {
 705            for task in completion_tasks {
 706                task.await.unwrap();
 707            }
 708
 709            zeta.update(cx, |zeta, _cx| {
 710                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
 711                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
 712            })
 713            .ok();
 714        })
 715    }
 716
 717    #[cfg(any(test, feature = "test-support"))]
 718    pub fn fake_completion(
 719        &mut self,
 720        project: Option<&Entity<Project>>,
 721        buffer: &Entity<Buffer>,
 722        position: language::Anchor,
 723        response: PredictEditsResponse,
 724        cx: &mut Context<Self>,
 725    ) -> Task<Result<Option<EditPrediction>>> {
 726        use std::future::ready;
 727
 728        self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
 729            ready(Ok((response, None)))
 730        })
 731    }
 732
 733    pub fn request_completion(
 734        &mut self,
 735        project: Option<&Entity<Project>>,
 736        buffer: &Entity<Buffer>,
 737        position: language::Anchor,
 738        can_collect_data: bool,
 739        cx: &mut Context<Self>,
 740    ) -> Task<Result<Option<EditPrediction>>> {
 741        self.request_completion_impl(
 742            self.workspace.upgrade(),
 743            project,
 744            buffer,
 745            position,
 746            can_collect_data,
 747            cx,
 748            Self::perform_predict_edits,
 749        )
 750    }
 751
 752    pub fn perform_predict_edits(
 753        params: PerformPredictEditsParams,
 754    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 755        async move {
 756            let PerformPredictEditsParams {
 757                client,
 758                llm_token,
 759                app_version,
 760                body,
 761                ..
 762            } = params;
 763
 764            let http_client = client.http_client();
 765            let mut token = llm_token.acquire(&client).await?;
 766            let mut did_retry = false;
 767
 768            loop {
 769                let request_builder = http_client::Request::builder().method(Method::POST);
 770                let request_builder =
 771                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 772                        request_builder.uri(predict_edits_url)
 773                    } else {
 774                        request_builder.uri(
 775                            http_client
 776                                .build_zed_llm_url("/predict_edits/v2", &[])?
 777                                .as_ref(),
 778                        )
 779                    };
 780                let request = request_builder
 781                    .header("Content-Type", "application/json")
 782                    .header("Authorization", format!("Bearer {}", token))
 783                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 784                    .body(serde_json::to_string(&body)?.into())?;
 785
 786                let mut response = http_client.send(request).await?;
 787
 788                if let Some(minimum_required_version) = response
 789                    .headers()
 790                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 791                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 792                {
 793                    anyhow::ensure!(
 794                        app_version >= minimum_required_version,
 795                        ZedUpdateRequiredError {
 796                            minimum_version: minimum_required_version
 797                        }
 798                    );
 799                }
 800
 801                if response.status().is_success() {
 802                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 803
 804                    let mut body = String::new();
 805                    response.body_mut().read_to_string(&mut body).await?;
 806                    return Ok((serde_json::from_str(&body)?, usage));
 807                } else if !did_retry
 808                    && response
 809                        .headers()
 810                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 811                        .is_some()
 812                {
 813                    did_retry = true;
 814                    token = llm_token.refresh(&client).await?;
 815                } else {
 816                    let mut body = String::new();
 817                    response.body_mut().read_to_string(&mut body).await?;
 818                    anyhow::bail!(
 819                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 820                        response.status(),
 821                        body
 822                    );
 823                }
 824            }
 825        }
 826    }
 827
 828    fn accept_edit_prediction(
 829        &mut self,
 830        request_id: EditPredictionId,
 831        cx: &mut Context<Self>,
 832    ) -> Task<Result<()>> {
 833        let client = self.client.clone();
 834        let llm_token = self.llm_token.clone();
 835        let app_version = AppVersion::global(cx);
 836        cx.spawn(async move |this, cx| {
 837            let http_client = client.http_client();
 838            let mut response = llm_token_retry(&llm_token, &client, |token| {
 839                let request_builder = http_client::Request::builder().method(Method::POST);
 840                let request_builder =
 841                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 842                        request_builder.uri(accept_prediction_url)
 843                    } else {
 844                        request_builder.uri(
 845                            http_client
 846                                .build_zed_llm_url("/predict_edits/accept", &[])?
 847                                .as_ref(),
 848                        )
 849                    };
 850                Ok(request_builder
 851                    .header("Content-Type", "application/json")
 852                    .header("Authorization", format!("Bearer {}", token))
 853                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 854                    .body(
 855                        serde_json::to_string(&AcceptEditPredictionBody {
 856                            request_id: request_id.0,
 857                        })?
 858                        .into(),
 859                    )?)
 860            })
 861            .await?;
 862
 863            if let Some(minimum_required_version) = response
 864                .headers()
 865                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 866                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 867                && app_version < minimum_required_version
 868            {
 869                return Err(anyhow!(ZedUpdateRequiredError {
 870                    minimum_version: minimum_required_version
 871                }));
 872            }
 873
 874            if response.status().is_success() {
 875                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 876                    this.update(cx, |this, cx| {
 877                        this.user_store.update(cx, |user_store, cx| {
 878                            user_store.update_edit_prediction_usage(usage, cx);
 879                        });
 880                    })?;
 881                }
 882
 883                Ok(())
 884            } else {
 885                let mut body = String::new();
 886                response.body_mut().read_to_string(&mut body).await?;
 887                Err(anyhow!(
 888                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 889                    response.status(),
 890                    body
 891                ))
 892            }
 893        })
 894    }
 895
    /// Turns a raw `PredictEditsResponse` into an [`EditPrediction`] for `buffer`.
    ///
    /// The model's `output_excerpt` is diffed against the editable region of
    /// `snapshot` on a background thread (via `parse_edits`). The resulting edits
    /// are then re-based onto the buffer's *current* snapshot with `interpolate`,
    /// because the buffer may have changed while the request was in flight.
    ///
    /// Resolves to `Ok(None)` when the edits can no longer be interpolated onto the
    /// current buffer contents, and to `Err` when `parse_edits` fails or reading the
    /// buffer through the async context fails.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing/diffing can be expensive, so it runs off the foreground thread
            // against the snapshot the request was built from.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-base the edits onto the buffer's latest snapshot; `interpolate`
            // returning `None` short-circuits the closure (and yields `Ok(None)` below).
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                // `anyhow::Ok` pins the error type of this async block.
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 957
    /// Extracts the model's rewritten editable region from `output_excerpt` and
    /// diffs it against the current text of `editable_range` in `snapshot`.
    ///
    /// Cursor markers are stripped first. The excerpt must then contain exactly one
    /// editable-region start marker, exactly one end marker, and at most one
    /// start-of-file marker. The new text is everything after the line holding the
    /// start marker, up to (not including) the newline that precedes the end marker.
    ///
    /// # Errors
    /// Returns an error if the marker counts are wrong, if no newline follows the
    /// start marker, or if the end marker is not preceded by a newline.
    fn parse_edits(
        output_excerpt: Arc<str>,
        editable_range: Range<usize>,
        snapshot: &BufferSnapshot,
    ) -> Result<Vec<(Range<Anchor>, String)>> {
        let content = output_excerpt.replace(CURSOR_MARKER, "");

        let start_markers = content
            .match_indices(EDITABLE_REGION_START_MARKER)
            .collect::<Vec<_>>();
        anyhow::ensure!(
            start_markers.len() == 1,
            "expected exactly one start marker, found {}",
            start_markers.len()
        );

        let end_markers = content
            .match_indices(EDITABLE_REGION_END_MARKER)
            .collect::<Vec<_>>();
        anyhow::ensure!(
            end_markers.len() == 1,
            "expected exactly one end marker, found {}",
            end_markers.len()
        );

        // Only the count of start-of-file markers is validated here.
        // NOTE(review): a start-of-file marker placed inside the editable region would
        // not be stripped from `new_text` — confirm the model never emits one there.
        let sof_markers = content
            .match_indices(START_OF_FILE_MARKER)
            .collect::<Vec<_>>();
        anyhow::ensure!(
            sof_markers.len() <= 1,
            "expected at most one start-of-file marker, found {}",
            sof_markers.len()
        );

        let codefence_start = start_markers[0].0;
        let content = &content[codefence_start..];

        // Drop the remainder of the start-marker line, including its newline.
        let newline_ix = content.find('\n').context("could not find newline")?;
        let content = &content[newline_ix + 1..];

        // The region ends right before the newline that introduces the end marker.
        let codefence_end = content
            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
            .context("could not find end marker")?;
        let new_text = &content[..codefence_end];

        let old_text = snapshot
            .text_for_range(editable_range.clone())
            .collect::<String>();

        Ok(Self::compute_edits(
            old_text,
            new_text,
            editable_range.start,
            snapshot,
        ))
    }
1014
1015    pub fn compute_edits(
1016        old_text: String,
1017        new_text: &str,
1018        offset: usize,
1019        snapshot: &BufferSnapshot,
1020    ) -> Vec<(Range<Anchor>, String)> {
1021        text_diff(&old_text, new_text)
1022            .into_iter()
1023            .map(|(mut old_range, new_text)| {
1024                old_range.start += offset;
1025                old_range.end += offset;
1026
1027                let prefix_len = common_prefix(
1028                    snapshot.chars_for_range(old_range.clone()),
1029                    new_text.chars(),
1030                );
1031                old_range.start += prefix_len;
1032
1033                let suffix_len = common_prefix(
1034                    snapshot.reversed_chars_for_range(old_range.clone()),
1035                    new_text[prefix_len..].chars().rev(),
1036                );
1037                old_range.end = old_range.end.saturating_sub(suffix_len);
1038
1039                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1040                let range = if old_range.is_empty() {
1041                    let anchor = snapshot.anchor_after(old_range.start);
1042                    anchor..anchor
1043                } else {
1044                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1045                };
1046                (range, new_text)
1047            })
1048            .collect()
1049    }
1050
1051    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1052        self.rated_completions.contains(&completion_id)
1053    }
1054
1055    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1056        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1057            let completion = self.shown_completions.pop_back().unwrap();
1058            self.rated_completions.remove(&completion.id);
1059        }
1060        self.shown_completions.push_front(completion.clone());
1061        cx.notify();
1062    }
1063
    /// Records the user's `rating` (and free-form `feedback`) for `completion`.
    ///
    /// Marks the completion as rated, emits an "Edit Prediction Rated" telemetry
    /// event carrying the prediction's inputs and output, and flushes telemetry
    /// right away instead of waiting for the next batch.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // The flush task is detached: this method does not wait for it to finish.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1084
1085    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1086        self.shown_completions.iter()
1087    }
1088
1089    pub fn shown_completions_len(&self) -> usize {
1090        self.shown_completions.len()
1091    }
1092
1093    fn report_changes_for_buffer(
1094        &mut self,
1095        buffer: &Entity<Buffer>,
1096        cx: &mut Context<Self>,
1097    ) -> BufferSnapshot {
1098        self.register_buffer(buffer, cx);
1099
1100        let registered_buffer = self
1101            .registered_buffers
1102            .get_mut(&buffer.entity_id())
1103            .unwrap();
1104        let new_snapshot = buffer.read(cx).snapshot();
1105
1106        if new_snapshot.version != registered_buffer.snapshot.version {
1107            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1108            self.push_event(Event::BufferChange {
1109                old_snapshot,
1110                new_snapshot: new_snapshot.clone(),
1111                timestamp: Instant::now(),
1112            });
1113        }
1114
1115        new_snapshot
1116    }
1117
1118    fn load_data_collection_choices() -> DataCollectionChoice {
1119        let choice = KEY_VALUE_STORE
1120            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1121            .log_err()
1122            .flatten();
1123
1124        match choice.as_deref() {
1125            Some("true") => DataCollectionChoice::Enabled,
1126            Some("false") => DataCollectionChoice::Disabled,
1127            Some(_) => {
1128                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1129                DataCollectionChoice::NotAnswered
1130            }
1131            None => DataCollectionChoice::NotAnswered,
1132        }
1133    }
1134
1135    fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1136        let now = Instant::now();
1137        if let Some(existing_ix) = self
1138            .recent_project_entries
1139            .iter()
1140            .rposition(|(id, _)| *id == project_entry_id)
1141        {
1142            self.recent_project_entries.remove(existing_ix);
1143        }
1144        if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1145            self.recent_project_entries.pop_front();
1146        }
1147        self.recent_project_entries
1148            .push_back((project_entry_id, now));
1149    }
1150
    /// Builds the list of recently active files to send with a prediction request,
    /// most recently active first, with paths relative to `repository`.
    ///
    /// Walks `recent_project_entries` from newest to oldest and prunes it as it goes:
    /// - entries outside `repository`, or reported by git as ignored/untracked, are
    ///   skipped but kept (another repository may claim them later);
    /// - entries that no longer resolve to a created, non-ignored, non-private,
    ///   non-external, non-FIFO file are removed;
    /// - entries whose repo path is not valid UTF-8 or is longer than
    ///   `MAX_RECENT_FILE_PATH_LENGTH` are removed;
    /// - entries whose age in milliseconds does not fit the target integer type
    ///   are removed.
    ///
    /// Returns an empty list if the workspace can no longer be read.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &Context<Self>,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::new();
        // Iterate by index in reverse so `remove(ix)` only shifts entries that were
        // already visited.
        for ix in (0..self.recent_project_entries.len()).rev() {
            let (entry_id, last_active_at) = &self.recent_project_entries[ix];
            if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
                && let worktree = worktree.read(cx)
                && let Some(entry) = worktree.entry_for_id(*entry_id)
                && entry.is_file()
                && entry.is_created()
                && !entry.is_ignored
                && !entry.is_private
                && !entry.is_external
                && !entry.is_fifo
            {
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: entry.path.clone(),
                };
                let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
                else {
                    // entry not removed since queries involving other repositories might occur later
                    continue;
                };
                let Some(repo_path_str) = repo_path.to_str() else {
                    // paths may not be valid UTF-8
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                    self.recent_project_entries.remove(ix);
                    continue;
                }
                if let Some(file_status) = repository.status_for_path(&repo_path) {
                    if file_status.is_ignored() || file_status.is_untracked() {
                        // entry not removed because it may belong to a nested repository
                        continue;
                    }
                }
                // The age must fit `active_to_now_ms`'s integer type; otherwise drop it.
                let Ok(active_to_now_ms) =
                    now.duration_since(*last_active_at).as_millis().try_into()
                else {
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                results.push(PredictEditsRecentFile {
                    repo_path: repo_path_str.to_string(),
                    active_to_now_ms,
                });
            } else {
                // The entry no longer resolves to an eligible file in the project.
                self.recent_project_entries.remove(ix);
            }
        }
        results
    }
1216}
1217
/// Inputs for sending one predict-edits request to the LLM service.
pub struct PerformPredictEditsParams {
    /// Zed client; presumably supplies the HTTP client and is used to refresh the
    /// LLM token — confirm against `perform_predict_edits`.
    pub client: Arc<Client>,
    /// Token used to authenticate the request.
    pub llm_token: LlmApiToken,
    /// Version of this Zed build; presumably sent as a request header and compared
    /// against the server's minimum required version.
    pub app_version: SemanticVersion,
    /// JSON payload describing the prediction context.
    pub body: PredictEditsBody,
}
1224
/// Returned when the server's `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response
/// header names a version newer than this Zed build.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version the server will accept.
    minimum_version: SemanticVersion,
}
1232
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two
/// character sequences.
///
/// The result is measured in bytes rather than chars so callers can use it
/// directly to shift byte offsets or slice a `str`.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1239
1240fn git_repository_for_file(
1241    project: &Entity<Project>,
1242    project_path: &ProjectPath,
1243    cx: &App,
1244) -> Option<Entity<Repository>> {
1245    let git_store = project.read(cx).git_store().read(cx);
1246    git_store
1247        .repository_and_path_for_project_path(project_path, cx)
1248        .map(|(repo, _repo_path)| repo)
1249}
1250
1251fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
1252    let head_sha = repository
1253        .head_commit
1254        .as_ref()
1255        .map(|head_commit| head_commit.sha.to_string());
1256    let remote_origin_url = repository.remote_origin_url.clone();
1257    let remote_upstream_url = repository.remote_upstream_url.clone();
1258    if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1259        return None;
1260    }
1261    Some(PredictEditsGitInfo {
1262        head_sha,
1263        remote_origin_url,
1264        remote_upstream_url,
1265    })
1266}
1267
/// Result of `gather_context`: the request body plus the region the model may rewrite.
pub struct GatherContextOutput {
    /// Payload to send to the predict-edits endpoint.
    pub body: PredictEditsBody,
    /// Byte-offset range of the editable region within the snapshot that was
    /// passed to `gather_context`.
    pub editable_range: Range<usize>,
}
1272
1273pub fn gather_context(
1274    project: Option<&Entity<Project>>,
1275    full_path_str: String,
1276    snapshot: &BufferSnapshot,
1277    cursor_point: language::Point,
1278    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1279    can_collect_data: bool,
1280    git_info: Option<PredictEditsGitInfo>,
1281    recent_files: Option<Vec<PredictEditsRecentFile>>,
1282    cx: &App,
1283) -> Task<Result<GatherContextOutput>> {
1284    let local_lsp_store =
1285        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1286    let diagnostic_groups: Vec<(String, serde_json::Value)> =
1287        if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
1288            snapshot
1289                .diagnostic_groups(None)
1290                .into_iter()
1291                .filter_map(|(language_server_id, diagnostic_group)| {
1292                    let language_server =
1293                        local_lsp_store.running_language_server_for_id(language_server_id)?;
1294                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1295                    let language_server_name = language_server.name().to_string();
1296                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
1297                    Some((language_server_name, serialized))
1298                })
1299                .collect::<Vec<_>>()
1300        } else {
1301            Vec::new()
1302        };
1303
1304    cx.background_spawn({
1305        let snapshot = snapshot.clone();
1306        async move {
1307            let diagnostic_groups = if diagnostic_groups.is_empty()
1308                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1309            {
1310                None
1311            } else {
1312                Some(diagnostic_groups)
1313            };
1314
1315            let input_excerpt = excerpt_for_cursor_position(
1316                cursor_point,
1317                &full_path_str,
1318                &snapshot,
1319                MAX_REWRITE_TOKENS,
1320                MAX_CONTEXT_TOKENS,
1321            );
1322            let input_events = make_events_prompt();
1323            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1324
1325            let body = PredictEditsBody {
1326                input_events,
1327                input_excerpt: input_excerpt.prompt,
1328                can_collect_data,
1329                diagnostic_groups,
1330                git_info,
1331                outline: None,
1332                speculated_output: None,
1333                recent_files,
1334            };
1335
1336            Ok(GatherContextOutput {
1337                body,
1338                editable_range,
1339            })
1340        }
1341    })
1342}
1343
1344fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1345    let mut result = String::new();
1346    for event in events.iter().rev() {
1347        let event_string = event.to_prompt();
1348        let event_tokens = tokens_for_bytes(event_string.len());
1349        if event_tokens > remaining_tokens {
1350            break;
1351        }
1352
1353        if !result.is_empty() {
1354            result.insert_str(0, "\n\n");
1355        }
1356        result.insert_str(0, &event_string);
1357        remaining_tokens -= event_tokens;
1358    }
1359    result
1360}
1361
/// State kept for each buffer Zeta tracks.
struct RegisteredBuffer {
    /// Last snapshot reported for this buffer; compared with the buffer's current
    /// version to detect changes.
    snapshot: BufferSnapshot,
    /// Subscriptions kept alive for as long as the buffer stays registered
    /// (created where the buffer is registered, outside this view).
    _subscriptions: [gpui::Subscription; 2],
}
1366
/// A user action recorded for inclusion in prediction prompts.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents (or path) changed between two snapshots.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1375
1376impl Event {
1377    fn to_prompt(&self) -> String {
1378        match self {
1379            Event::BufferChange {
1380                old_snapshot,
1381                new_snapshot,
1382                ..
1383            } => {
1384                let mut prompt = String::new();
1385
1386                let old_path = old_snapshot
1387                    .file()
1388                    .map(|f| f.path().as_ref())
1389                    .unwrap_or(Path::new("untitled"));
1390                let new_path = new_snapshot
1391                    .file()
1392                    .map(|f| f.path().as_ref())
1393                    .unwrap_or(Path::new("untitled"));
1394                if old_path != new_path {
1395                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1396                }
1397
1398                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1399                if !diff.is_empty() {
1400                    write!(
1401                        prompt,
1402                        "User edited {:?}:\n```diff\n{}\n```",
1403                        new_path, diff
1404                    )
1405                    .unwrap();
1406                }
1407
1408                prompt
1409            }
1410        }
1411    }
1412}
1413
/// The prediction currently held by the provider, paired with the buffer it targets.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer this prediction applies to.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1419
1420impl CurrentEditPrediction {
1421    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1422        if self.buffer_id != old_completion.buffer_id {
1423            return true;
1424        }
1425
1426        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1427            return true;
1428        };
1429        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1430            return false;
1431        };
1432
1433        if old_edits.len() == 1 && new_edits.len() == 1 {
1434            let (old_range, old_text) = &old_edits[0];
1435            let (new_range, new_text) = &new_edits[0];
1436            new_range == old_range && new_text.starts_with(old_text)
1437        } else {
1438            true
1439        }
1440    }
1441}
1442
/// A prediction request that is still in flight.
struct PendingCompletion {
    /// Identifier of the request; presumably taken from the provider's
    /// `next_pending_completion_id` counter — confirm where it is pushed.
    id: usize,
    /// The spawned request task, owned so it stays alive while pending.
    _task: Task<()>,
}
1447
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// `true` once the user has chosen either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Flips the choice; an unanswered choice becomes `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(enabled: bool) -> Self {
        if enabled {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1487
/// Data-collection state for the buffer an edit prediction provider is attached to.
pub struct ProviderDataCollection {
    /// `None` when data collection is not possible for the provider's buffer: it has
    /// no file, the file is not local or is private, or its worktree has no license
    /// detection watcher.
    choice: Option<Entity<DataCollectionChoice>>,
    /// Reports whether the buffer's worktree appears to be open source; `None`
    /// exactly when `choice` is `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1493
1494impl ProviderDataCollection {
1495    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1496        let choice_and_watcher = buffer.and_then(|buffer| {
1497            let file = buffer.read(cx).file()?;
1498
1499            if !file.is_local() || file.is_private() {
1500                return None;
1501            }
1502
1503            let zeta = zeta.read(cx);
1504            let choice = zeta.data_collection_choice.clone();
1505
1506            let license_detection_watcher = zeta
1507                .license_detection_watchers
1508                .get(&file.worktree_id(cx))
1509                .cloned()?;
1510
1511            Some((choice, license_detection_watcher))
1512        });
1513
1514        if let Some((choice, watcher)) = choice_and_watcher {
1515            ProviderDataCollection {
1516                choice: Some(choice),
1517                license_detection_watcher: Some(watcher),
1518            }
1519        } else {
1520            ProviderDataCollection {
1521                choice: None,
1522                license_detection_watcher: None,
1523            }
1524        }
1525    }
1526
1527    pub fn can_collect_data(&self, cx: &App) -> bool {
1528        self.is_data_collection_enabled(cx) && self.is_project_open_source()
1529    }
1530
1531    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1532        self.choice
1533            .as_ref()
1534            .is_some_and(|choice| choice.read(cx).is_enabled())
1535    }
1536
1537    fn is_project_open_source(&self) -> bool {
1538        self.license_detection_watcher
1539            .as_ref()
1540            .is_some_and(|watcher| watcher.is_project_open_source())
1541    }
1542
1543    pub fn toggle(&mut self, cx: &mut App) {
1544        if let Some(choice) = self.choice.as_mut() {
1545            let new_choice = choice.update(cx, |choice, _cx| {
1546                let new_choice = choice.toggle();
1547                *choice = new_choice;
1548                new_choice
1549            });
1550
1551            db::write_and_log(cx, move || {
1552                KEY_VALUE_STORE.write_kvp(
1553                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1554                    new_choice.is_enabled().to_string(),
1555                )
1556            });
1557        }
1558    }
1559}
1560
1561async fn llm_token_retry(
1562    llm_token: &LlmApiToken,
1563    client: &Arc<Client>,
1564    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1565) -> Result<Response<AsyncBody>> {
1566    let mut did_retry = false;
1567    let http_client = client.http_client();
1568    let mut token = llm_token.acquire(client).await?;
1569    loop {
1570        let request = build_request(token.clone())?;
1571        let response = http_client.send(request).await?;
1572
1573        if !did_retry
1574            && !response.status().is_success()
1575            && response
1576                .headers()
1577                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1578                .is_some()
1579        {
1580            did_retry = true;
1581            token = llm_token.refresh(client).await?;
1582            continue;
1583        }
1584
1585        return Ok(response);
1586    }
1587}
1588
/// Edit prediction provider backed by a [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction requests still in flight; at most two are tracked at once.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Counter that supplies the id of the next pending request.
    next_pending_completion_id: usize,
    /// Most recent prediction; `refresh` skips new requests while it still
    /// interpolates onto the buffer.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider's buffer. This field is always
    /// present; collection is disabled when its inner choice is `None`.
    provider_data_collection: ProviderDataCollection,
    /// When the last request was started; new requests wait until
    /// `THROTTLE_TIMEOUT` has elapsed since then.
    last_request_timestamp: Instant,
}
1598
1599impl ZetaEditPredictionProvider {
1600    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1601
1602    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1603        Self {
1604            zeta,
1605            pending_completions: ArrayVec::new(),
1606            next_pending_completion_id: 0,
1607            current_completion: None,
1608            provider_data_collection,
1609            last_request_timestamp: Instant::now(),
1610        }
1611    }
1612}
1613
1614impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1615    fn name() -> &'static str {
1616        "zed-predict"
1617    }
1618
1619    fn display_name() -> &'static str {
1620        "Zed's Edit Predictions"
1621    }
1622
1623    fn show_completions_in_menu() -> bool {
1624        true
1625    }
1626
1627    fn show_tab_accept_marker() -> bool {
1628        true
1629    }
1630
1631    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1632        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1633
1634        if self.provider_data_collection.is_data_collection_enabled(cx) {
1635            DataCollectionState::Enabled {
1636                is_project_open_source,
1637            }
1638        } else {
1639            DataCollectionState::Disabled {
1640                is_project_open_source,
1641            }
1642        }
1643    }
1644
1645    fn toggle_data_collection(&mut self, cx: &mut App) {
1646        self.provider_data_collection.toggle(cx);
1647    }
1648
1649    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1650        self.zeta.read(cx).usage(cx)
1651    }
1652
1653    fn is_enabled(
1654        &self,
1655        _buffer: &Entity<Buffer>,
1656        _cursor_position: language::Anchor,
1657        _cx: &App,
1658    ) -> bool {
1659        true
1660    }
1661    fn is_refreshing(&self) -> bool {
1662        !self.pending_completions.is_empty()
1663    }
1664
1665    fn refresh(
1666        &mut self,
1667        project: Option<Entity<Project>>,
1668        buffer: Entity<Buffer>,
1669        position: language::Anchor,
1670        _debounce: bool,
1671        cx: &mut Context<Self>,
1672    ) {
1673        if self.zeta.read(cx).update_required {
1674            return;
1675        }
1676
1677        if self
1678            .zeta
1679            .read(cx)
1680            .user_store
1681            .read_with(cx, |user_store, _cx| {
1682                user_store.account_too_young() || user_store.has_overdue_invoices()
1683            })
1684        {
1685            return;
1686        }
1687
1688        if let Some(current_completion) = self.current_completion.as_ref() {
1689            let snapshot = buffer.read(cx).snapshot();
1690            if current_completion
1691                .completion
1692                .interpolate(&snapshot)
1693                .is_some()
1694            {
1695                return;
1696            }
1697        }
1698
1699        let pending_completion_id = self.next_pending_completion_id;
1700        self.next_pending_completion_id += 1;
1701        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1702        let last_request_timestamp = self.last_request_timestamp;
1703
1704        let task = cx.spawn(async move |this, cx| {
1705            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1706                .checked_duration_since(Instant::now())
1707            {
1708                cx.background_executor().timer(timeout).await;
1709            }
1710
1711            let completion_request = this.update(cx, |this, cx| {
1712                this.last_request_timestamp = Instant::now();
1713                this.zeta.update(cx, |zeta, cx| {
1714                    zeta.request_completion(
1715                        project.as_ref(),
1716                        &buffer,
1717                        position,
1718                        can_collect_data,
1719                        cx,
1720                    )
1721                })
1722            });
1723
1724            let completion = match completion_request {
1725                Ok(completion_request) => {
1726                    let completion_request = completion_request.await;
1727                    completion_request.map(|c| {
1728                        c.map(|completion| CurrentEditPrediction {
1729                            buffer_id: buffer.entity_id(),
1730                            completion,
1731                        })
1732                    })
1733                }
1734                Err(error) => Err(error),
1735            };
1736            let Some(new_completion) = completion
1737                .context("edit prediction failed")
1738                .log_err()
1739                .flatten()
1740            else {
1741                this.update(cx, |this, cx| {
1742                    if this.pending_completions[0].id == pending_completion_id {
1743                        this.pending_completions.remove(0);
1744                    } else {
1745                        this.pending_completions.clear();
1746                    }
1747
1748                    cx.notify();
1749                })
1750                .ok();
1751                return;
1752            };
1753
1754            this.update(cx, |this, cx| {
1755                if this.pending_completions[0].id == pending_completion_id {
1756                    this.pending_completions.remove(0);
1757                } else {
1758                    this.pending_completions.clear();
1759                }
1760
1761                if let Some(old_completion) = this.current_completion.as_ref() {
1762                    let snapshot = buffer.read(cx).snapshot();
1763                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1764                        this.zeta.update(cx, |zeta, cx| {
1765                            zeta.completion_shown(&new_completion.completion, cx);
1766                        });
1767                        this.current_completion = Some(new_completion);
1768                    }
1769                } else {
1770                    this.zeta.update(cx, |zeta, cx| {
1771                        zeta.completion_shown(&new_completion.completion, cx);
1772                    });
1773                    this.current_completion = Some(new_completion);
1774                }
1775
1776                cx.notify();
1777            })
1778            .ok();
1779        });
1780
1781        // We always maintain at most two pending completions. When we already
1782        // have two, we replace the newest one.
1783        if self.pending_completions.len() <= 1 {
1784            self.pending_completions.push(PendingCompletion {
1785                id: pending_completion_id,
1786                _task: task,
1787            });
1788        } else if self.pending_completions.len() == 2 {
1789            self.pending_completions.pop();
1790            self.pending_completions.push(PendingCompletion {
1791                id: pending_completion_id,
1792                _task: task,
1793            });
1794        }
1795    }
1796
1797    fn cycle(
1798        &mut self,
1799        _buffer: Entity<Buffer>,
1800        _cursor_position: language::Anchor,
1801        _direction: edit_prediction::Direction,
1802        _cx: &mut Context<Self>,
1803    ) {
1804        // Right now we don't support cycling.
1805    }
1806
1807    fn accept(&mut self, cx: &mut Context<Self>) {
1808        let completion_id = self
1809            .current_completion
1810            .as_ref()
1811            .map(|completion| completion.completion.id);
1812        if let Some(completion_id) = completion_id {
1813            self.zeta
1814                .update(cx, |zeta, cx| {
1815                    zeta.accept_edit_prediction(completion_id, cx)
1816                })
1817                .detach();
1818        }
1819        self.pending_completions.clear();
1820    }
1821
1822    fn discard(&mut self, _cx: &mut Context<Self>) {
1823        self.pending_completions.clear();
1824        self.current_completion.take();
1825    }
1826
1827    fn suggest(
1828        &mut self,
1829        buffer: &Entity<Buffer>,
1830        cursor_position: language::Anchor,
1831        cx: &mut Context<Self>,
1832    ) -> Option<edit_prediction::EditPrediction> {
1833        let CurrentEditPrediction {
1834            buffer_id,
1835            completion,
1836            ..
1837        } = self.current_completion.as_mut()?;
1838
1839        // Invalidate previous completion if it was generated for a different buffer.
1840        if *buffer_id != buffer.entity_id() {
1841            self.current_completion.take();
1842            return None;
1843        }
1844
1845        let buffer = buffer.read(cx);
1846        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1847            self.current_completion.take();
1848            return None;
1849        };
1850
1851        let cursor_row = cursor_position.to_point(buffer).row;
1852        let (closest_edit_ix, (closest_edit_range, _)) =
1853            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1854                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1855                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1856                cmp::min(distance_from_start, distance_from_end)
1857            })?;
1858
1859        let mut edit_start_ix = closest_edit_ix;
1860        for (range, _) in edits[..edit_start_ix].iter().rev() {
1861            let distance_from_closest_edit =
1862                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1863            if distance_from_closest_edit <= 1 {
1864                edit_start_ix -= 1;
1865            } else {
1866                break;
1867            }
1868        }
1869
1870        let mut edit_end_ix = closest_edit_ix + 1;
1871        for (range, _) in &edits[edit_end_ix..] {
1872            let distance_from_closest_edit =
1873                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1874            if distance_from_closest_edit <= 1 {
1875                edit_end_ix += 1;
1876            } else {
1877                break;
1878            }
1879        }
1880
1881        Some(edit_prediction::EditPrediction {
1882            id: Some(completion.id.to_string().into()),
1883            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1884            edit_preview: Some(completion.edit_preview.clone()),
1885        })
1886    }
1887}
1888
/// Estimates how many model tokens a text of `bytes` bytes occupies.
///
/// The estimate uses integer division, so any partial token is rounded down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const BYTES_PER_TOKEN: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN)
}
1895
// Unit tests for prediction interpolation, for the clean-up of edits derived from model
// output, and for predictions that extend past the end of the buffer. Network access is
// replaced by `FakeHttpClient` handlers that serve canned responses for the LLM-token
// and `/predict_edits/v2` endpoints (404 for anything else).
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    /// Checks that `EditPrediction::interpolate` re-bases predicted edits while the user
    /// keeps editing the buffer:
    /// - typing part of a predicted insertion shrinks that edit;
    /// - typing all of it, or performing a predicted deletion, removes that edit;
    /// - undo or deleting typed text restores the earlier edits.
    ///
    /// The final edit overlaps the remaining predicted insertion point, and interpolation
    /// is expected to return `None` after it.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Prediction: replace "rem" with "REM" and delete "um" from "ipsum".
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        // Only `edits`, `edit_preview` and `snapshot` matter for interpolation here; the
        // remaining fields are placeholders.
        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Unchanged buffer: edits come back as predicted.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User deletes "rem": the replacement becomes a pure insertion and the
            // second edit shifts left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User types the predicted text one character at a time.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // The insertion is fully typed, so only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the typed "M" brings the remaining insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User performs the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit over the remaining insertion point: interpolation yields `None`.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Checks that edits derived from a model response are minimal: only the inserted
    /// characters are reported, not whole rewritten lines.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// Checks that a prediction adding a line after the buffer's final newline applies
    /// cleanly: the buffer `"lorem\n"` becomes `"lorem\nipsum"`.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake backend: issues an LLM token and answers every prediction request with
        // `completion_response`.
        let http_client = FakeHttpClient::create(move |req| async move {
            match (req.method(), req.uri().path()) {
                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&CreateLlmTokenResponse {
                            token: LlmToken("the-llm-token".to_string()),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                                .unwrap(),
                            output_excerpt: completion_response.to_string(),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                _ => Ok(http_client::Response::builder()
                    .status(404)
                    .body("Not Found".into())
                    .unwrap()),
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Cursor on the empty last line, after the trailing newline.
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Requests a prediction for `buffer_content` from a fake backend that answers
    /// `/predict_edits/v2` with `completion_response` as the output excerpt.
    ///
    /// The cursor is placed at `Point::new(1, 0)`. The resulting edits are returned as
    /// point ranges resolved against the buffer snapshot taken before the request.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        // The handler may be called more than once, so it clones the response per call.
        let http_client = FakeHttpClient::create(move |req| {
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`.
    ///
    /// Starts use `anchor_after` and ends use `anchor_before`. Presumably this is so that
    /// text inserted exactly at a range boundary is not absorbed into the range.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Inverse of `to_completion_edits`: resolves anchor ranges to byte offsets in the
    /// buffer's current contents.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    // Initializes test logging once for the whole test binary, via a `ctor` constructor.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}