zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9use arrayvec::ArrayVec;
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction::DataCollectionState;
  13pub use init::*;
  14use license_detection::LicenseDetectionWatcher;
  15use project::git_store::Repository;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use client::{Client, EditPredictionUsage, UserStore};
  20use cloud_llm_client::{
  21    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  22    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
  23    ZED_VERSION_HEADER_NAME,
  24};
  25use collections::{HashMap, HashSet, VecDeque};
  26use futures::AsyncReadExt;
  27use gpui::{
  28    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  29    Subscription, Task, WeakEntity, actions,
  30};
  31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  32use input_excerpt::excerpt_for_cursor_position;
  33use language::{
  34    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  35};
  36use language_model::{LlmApiToken, RefreshLlmTokenListener};
  37use project::{Project, ProjectEntryId, ProjectPath};
  38use release_channel::AppVersion;
  39use settings::WorktreeId;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    mem,
  46    ops::Range,
  47    path::Path,
  48    rc::Rc,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::EditPredictionRating;
  53use thiserror::Error;
  54use util::ResultExt;
  55use uuid::Uuid;
  56use workspace::Workspace;
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  58use worktree::Worktree;
  59
// Marker strings used in prompt/response excerpts.
// NOTE(review): in this part of the file only the editable-region markers are
// used (they wrap the fake `output_excerpt`s in `fill_with_fake_completions`);
// confirm how the cursor and start-of-file markers are consumed.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
/// Maximum gap between two consecutive changes to the same buffer for
/// `Zeta::push_event` to merge them into a single event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Key-value store key; `ZedPredictUpsell::dismissed` treats any value stored
/// under it as proof that the upsell was already dismissed.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

// Size limits used when building a request.
// NOTE(review): only `MAX_EVENT_TOKENS` is used in this part of the file (it is
// passed to `prompt_for_events`); the names suggest token budgets and a cap on
// diagnostic groups — confirm against the code that consumes them.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;
const MAX_DIAGNOSTIC_GROUPS: usize = 10;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;

/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;

/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
  83
  84actions!(
  85    edit_prediction,
  86    [
  87        /// Clears the edit prediction history.
  88        ClearHistory
  89    ]
  90);
  91
  92#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  93pub struct EditPredictionId(Uuid);
  94
  95impl From<EditPredictionId> for gpui::ElementId {
  96    fn from(value: EditPredictionId) -> Self {
  97        gpui::ElementId::Uuid(value.0)
  98    }
  99}
 100
 101impl std::fmt::Display for EditPredictionId {
 102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 103        write!(f, "{}", self.0)
 104    }
 105}
 106
 107struct ZedPredictUpsell;
 108
 109impl Dismissable for ZedPredictUpsell {
 110    const KEY: &'static str = "dismissed-edit-predict-upsell";
 111
 112    fn dismissed() -> bool {
 113        // To make this backwards compatible with older versions of Zed, we
 114        // check if the user has seen the previous Edit Prediction Onboarding
 115        // before, by checking the data collection choice which was written to
 116        // the database once the user clicked on "Accept and Enable"
 117        if KEY_VALUE_STORE
 118            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 119            .log_err()
 120            .is_some_and(|s| s.is_some())
 121        {
 122            return true;
 123        }
 124
 125        KEY_VALUE_STORE
 126            .read_kvp(Self::KEY)
 127            .log_err()
 128            .is_some_and(|s| s.is_some())
 129    }
 130}
 131
 132pub fn should_show_upsell_modal() -> bool {
 133    !ZedPredictUpsell::dismissed()
 134}
 135
/// Wrapper that stores the shared [`Zeta`] entity as a gpui global; it is set
/// by `Zeta::register` and read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 140
/// A single edit prediction together with the inputs and timestamps of the
/// request that produced it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    id: EditPredictionId,
    /// Path recorded for the buffer the prediction was requested for.
    path: Arc<Path>,
    // NOTE(review): presumably the offset range of the excerpt sent to the
    // model — confirm where this is populated.
    excerpt_range: Range<usize>,
    /// Offset of the cursor in `snapshot` when the request was made.
    cursor_offset: usize,
    /// Predicted edits as `(range to replace, replacement text)` pairs; the
    /// anchors are resolved against `snapshot` during interpolation.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` are resolved against.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
    // Inputs sent with the request and the excerpt returned in the response.
    input_outline: Arc<str>,
    input_events: Arc<str>,
    input_excerpt: Arc<str>,
    output_excerpt: Arc<str>,
    /// Start point used by `latency`.
    buffer_snapshotted_at: Instant,
    /// End point used by `latency`.
    response_received_at: Instant,
}
 157
 158impl EditPrediction {
 159    fn latency(&self) -> Duration {
 160        self.response_received_at
 161            .duration_since(self.buffer_snapshotted_at)
 162    }
 163
 164    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 165        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 166    }
 167}
 168
 169fn interpolate(
 170    old_snapshot: &BufferSnapshot,
 171    new_snapshot: &BufferSnapshot,
 172    current_edits: Arc<[(Range<Anchor>, String)]>,
 173) -> Option<Vec<(Range<Anchor>, String)>> {
 174    let mut edits = Vec::new();
 175
 176    let mut model_edits = current_edits.iter().peekable();
 177    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 178        while let Some((model_old_range, _)) = model_edits.peek() {
 179            let model_old_range = model_old_range.to_offset(old_snapshot);
 180            if model_old_range.end < user_edit.old.start {
 181                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 182                edits.push((model_old_range.clone(), model_new_text.clone()));
 183            } else {
 184                break;
 185            }
 186        }
 187
 188        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 189            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 190            if user_edit.old == model_old_offset_range {
 191                let user_new_text = new_snapshot
 192                    .text_for_range(user_edit.new.clone())
 193                    .collect::<String>();
 194
 195                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 196                    if !model_suffix.is_empty() {
 197                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 198                        edits.push((anchor..anchor, model_suffix.to_string()));
 199                    }
 200
 201                    model_edits.next();
 202                    continue;
 203                }
 204            }
 205        }
 206
 207        return None;
 208    }
 209
 210    edits.extend(model_edits.cloned());
 211
 212    if edits.is_empty() { None } else { Some(edits) }
 213}
 214
 215impl std::fmt::Debug for EditPrediction {
 216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 217        f.debug_struct("EditPrediction")
 218            .field("id", &self.id)
 219            .field("path", &self.path)
 220            .field("edits", &self.edits)
 221            .finish_non_exhaustive()
 222    }
 223}
 224
/// State behind edit predictions: tracked buffers and their recent changes,
/// the client/token used for requests, and predictions kept for feedback.
pub struct Zeta {
    /// Workspace handle; upgraded in `request_completion` so an
    /// update-required notification can be shown.
    workspace: WeakEntity<Workspace>,
    /// Client used for HTTP requests and for acquiring/refreshing the LLM token.
    client: Arc<Client>,
    /// Recent buffer-change events; `push_event` keeps this bounded by
    /// `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers registered through `register_buffer`, keyed by entity id and
    /// removed again when the buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions kept for feedback (capacity `MAX_SHOWN_COMPLETION_COUNT`).
    shown_completions: VecDeque<EditPrediction>,
    // NOTE(review): presumably the ids of predictions the user has already
    // rated — confirm where this is populated.
    rated_completions: HashSet<EditPredictionId>,
    /// Data collection choice, loaded by `load_data_collection_choices` in `new`.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token sent as the bearer credential on prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the subscription alive that refreshes `llm_token` whenever the
    /// `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Source and sink of the edit prediction usage information.
    user_store: Entity<UserStore>,
    /// One license detection watcher per worktree passed to `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Project entries that recently became active, fed by
    /// `project::Event::ActiveEntryChanged`.
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
 241
 242impl Zeta {
 243    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 244        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 245    }
 246
 247    pub fn register(
 248        workspace: Option<Entity<Workspace>>,
 249        worktree: Option<Entity<Worktree>>,
 250        client: Arc<Client>,
 251        user_store: Entity<UserStore>,
 252        cx: &mut App,
 253    ) -> Entity<Self> {
 254        let this = Self::global(cx).unwrap_or_else(|| {
 255            let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
 256            cx.set_global(ZetaGlobal(entity.clone()));
 257            entity
 258        });
 259
 260        this.update(cx, move |this, cx| {
 261            if let Some(worktree) = worktree {
 262                let worktree_id = worktree.read(cx).id();
 263                this.license_detection_watchers
 264                    .entry(worktree_id)
 265                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 266            }
 267        });
 268
 269        this
 270    }
 271
    /// Discards all recorded buffer-change events.
    pub fn clear_history(&mut self) {
        self.events.clear();
    }
 275
 276    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 277        self.user_store.read(cx).edit_prediction_usage()
 278    }
 279
 280    fn new(
 281        workspace: Option<Entity<Workspace>>,
 282        client: Arc<Client>,
 283        user_store: Entity<UserStore>,
 284        cx: &mut Context<Self>,
 285    ) -> Self {
 286        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 287
 288        let data_collection_choice = Self::load_data_collection_choices();
 289        let data_collection_choice = cx.new(|_| data_collection_choice);
 290
 291        if let Some(workspace) = &workspace {
 292            cx.subscribe(
 293                &workspace.read(cx).project().clone(),
 294                |this, _workspace, event, _cx| match event {
 295                    project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
 296                        this.push_recent_project_entry(*project_entry_id)
 297                    }
 298                    _ => {}
 299                },
 300            )
 301            .detach();
 302        }
 303
 304        Self {
 305            workspace: workspace.map_or_else(
 306                || WeakEntity::new_invalid(),
 307                |workspace| workspace.downgrade(),
 308            ),
 309            client,
 310            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 311            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
 312            rated_completions: HashSet::default(),
 313            registered_buffers: HashMap::default(),
 314            data_collection_choice,
 315            llm_token: LlmApiToken::default(),
 316            _llm_token_subscription: cx.subscribe(
 317                &refresh_llm_token_listener,
 318                |this, _listener, _event, cx| {
 319                    let client = this.client.clone();
 320                    let llm_token = this.llm_token.clone();
 321                    cx.spawn(async move |_this, _cx| {
 322                        llm_token.refresh(&client).await?;
 323                        anyhow::Ok(())
 324                    })
 325                    .detach_and_log_err(cx);
 326                },
 327            ),
 328            update_required: false,
 329            license_detection_watchers: HashMap::default(),
 330            user_store,
 331            recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
 332        }
 333    }
 334
 335    fn push_event(&mut self, event: Event) {
 336        if let Some(Event::BufferChange {
 337            new_snapshot: last_new_snapshot,
 338            timestamp: last_timestamp,
 339            ..
 340        }) = self.events.back_mut()
 341        {
 342            // Coalesce edits for the same buffer when they happen one after the other.
 343            let Event::BufferChange {
 344                old_snapshot,
 345                new_snapshot,
 346                timestamp,
 347            } = &event;
 348
 349            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 350                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 351                && old_snapshot.version == last_new_snapshot.version
 352            {
 353                *last_new_snapshot = new_snapshot.clone();
 354                *last_timestamp = *timestamp;
 355                return;
 356            }
 357        }
 358
 359        if self.events.len() >= MAX_EVENT_COUNT {
 360            // These are halved instead of popping to improve prompt caching.
 361            self.events.drain(..MAX_EVENT_COUNT / 2);
 362        }
 363
 364        self.events.push_back(event);
 365    }
 366
 367    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 368        let buffer_id = buffer.entity_id();
 369        let weak_buffer = buffer.downgrade();
 370
 371        if let std::collections::hash_map::Entry::Vacant(entry) =
 372            self.registered_buffers.entry(buffer_id)
 373        {
 374            let snapshot = buffer.read(cx).snapshot();
 375
 376            entry.insert(RegisteredBuffer {
 377                snapshot,
 378                _subscriptions: [
 379                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 380                        this.handle_buffer_event(buffer, event, cx);
 381                    }),
 382                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 383                        this.registered_buffers.remove(&weak_buffer.entity_id());
 384                    }),
 385                ],
 386            });
 387        };
 388    }
 389
 390    fn handle_buffer_event(
 391        &mut self,
 392        buffer: Entity<Buffer>,
 393        event: &language::BufferEvent,
 394        cx: &mut Context<Self>,
 395    ) {
 396        if let language::BufferEvent::Edited = event {
 397            self.report_changes_for_buffer(&buffer, cx);
 398        }
 399    }
 400
    /// Shared implementation behind `request_completion` and `fake_completion`.
    ///
    /// Reports pending changes for `buffer`, gathers the request context, hands
    /// the request body to `perform_predict_edits`, and turns the response into
    /// an [`EditPrediction`] via `process_completion_response`.
    ///
    /// - `workspace`: when present, used to show a notification if the server
    ///   demands a newer Zed version.
    /// - `can_collect_data`: git info is gathered only for `CanCollectData(true)`.
    /// - `perform_predict_edits`: performs the actual request, which lets
    ///   `fake_completion` substitute a canned response.
    ///
    /// The returned task fails if gathering context or the request fails.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Taken before any other work so the latency figures below include it.
        let buffer_snapshotted_at = Instant::now();
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        let git_info = if matches!(can_collect_data, CanCollectData(true)) {
            self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx)
        } else {
            None
        };

        // Buffers without a backing file are reported under the path "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the inputs: `body` is moved into the request below.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // An outdated client is remembered on the entity and, when a
                    // workspace is available, surfaced with a link to the releases
                    // page. The error is returned to the caller either way.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Usage information is optional; when present it is forwarded to the
            // user store.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests
            // (3 of the 256 possible `u8` values pass this check)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 555
    /// Generates several example completions of various states to fill the Zeta completion modal.
    ///
    /// Seven canned responses are requested through `fake_completion` against
    /// a throwaway local buffer. Once all of them have finished, the edits of
    /// the shown completions at indices 2 and 3 are replaced with an empty list.
    ///
    /// # Panics
    ///
    /// The returned task panics if any fake completion fails, or if fewer than
    /// four shown completions exist afterwards.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Every fake request uses the start of the second line as its position.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            // This response and the next one repeat the buffer's lines unchanged.
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await the tasks one after another, in the order they were created.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // NOTE(review): presumably this produces examples of predictions
            // that carry no edits — confirm against the modal's rendering.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 708
 709    #[cfg(any(test, feature = "test-support"))]
 710    pub fn fake_completion(
 711        &mut self,
 712        project: Option<&Entity<Project>>,
 713        buffer: &Entity<Buffer>,
 714        position: language::Anchor,
 715        response: PredictEditsResponse,
 716        cx: &mut Context<Self>,
 717    ) -> Task<Result<Option<EditPrediction>>> {
 718        use std::future::ready;
 719
 720        self.request_completion_impl(
 721            None,
 722            project,
 723            buffer,
 724            position,
 725            CanCollectData(false),
 726            cx,
 727            |_params| ready(Ok((response, None))),
 728        )
 729    }
 730
 731    pub fn request_completion(
 732        &mut self,
 733        project: Option<&Entity<Project>>,
 734        buffer: &Entity<Buffer>,
 735        position: language::Anchor,
 736        can_collect_data: CanCollectData,
 737        cx: &mut Context<Self>,
 738    ) -> Task<Result<Option<EditPrediction>>> {
 739        self.request_completion_impl(
 740            self.workspace.upgrade(),
 741            project,
 742            buffer,
 743            position,
 744            can_collect_data,
 745            cx,
 746            Self::perform_predict_edits,
 747        )
 748    }
 749
 750    pub fn perform_predict_edits(
 751        params: PerformPredictEditsParams,
 752    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 753        async move {
 754            let PerformPredictEditsParams {
 755                client,
 756                llm_token,
 757                app_version,
 758                body,
 759                ..
 760            } = params;
 761
 762            let http_client = client.http_client();
 763            let mut token = llm_token.acquire(&client).await?;
 764            let mut did_retry = false;
 765
 766            loop {
 767                let request_builder = http_client::Request::builder().method(Method::POST);
 768                let request_builder =
 769                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 770                        request_builder.uri(predict_edits_url)
 771                    } else {
 772                        request_builder.uri(
 773                            http_client
 774                                .build_zed_llm_url("/predict_edits/v2", &[])?
 775                                .as_ref(),
 776                        )
 777                    };
 778                let request = request_builder
 779                    .header("Content-Type", "application/json")
 780                    .header("Authorization", format!("Bearer {}", token))
 781                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 782                    .body(serde_json::to_string(&body)?.into())?;
 783
 784                let mut response = http_client.send(request).await?;
 785
 786                if let Some(minimum_required_version) = response
 787                    .headers()
 788                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 789                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 790                {
 791                    anyhow::ensure!(
 792                        app_version >= minimum_required_version,
 793                        ZedUpdateRequiredError {
 794                            minimum_version: minimum_required_version
 795                        }
 796                    );
 797                }
 798
 799                if response.status().is_success() {
 800                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 801
 802                    let mut body = String::new();
 803                    response.body_mut().read_to_string(&mut body).await?;
 804                    return Ok((serde_json::from_str(&body)?, usage));
 805                } else if !did_retry
 806                    && response
 807                        .headers()
 808                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 809                        .is_some()
 810                {
 811                    did_retry = true;
 812                    token = llm_token.refresh(&client).await?;
 813                } else {
 814                    let mut body = String::new();
 815                    response.body_mut().read_to_string(&mut body).await?;
 816                    anyhow::bail!(
 817                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 818                        response.status(),
 819                        body
 820                    );
 821                }
 822            }
 823        }
 824    }
 825
 826    fn accept_edit_prediction(
 827        &mut self,
 828        request_id: EditPredictionId,
 829        cx: &mut Context<Self>,
 830    ) -> Task<Result<()>> {
 831        let client = self.client.clone();
 832        let llm_token = self.llm_token.clone();
 833        let app_version = AppVersion::global(cx);
 834        cx.spawn(async move |this, cx| {
 835            let http_client = client.http_client();
 836            let mut response = llm_token_retry(&llm_token, &client, |token| {
 837                let request_builder = http_client::Request::builder().method(Method::POST);
 838                let request_builder =
 839                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 840                        request_builder.uri(accept_prediction_url)
 841                    } else {
 842                        request_builder.uri(
 843                            http_client
 844                                .build_zed_llm_url("/predict_edits/accept", &[])?
 845                                .as_ref(),
 846                        )
 847                    };
 848                Ok(request_builder
 849                    .header("Content-Type", "application/json")
 850                    .header("Authorization", format!("Bearer {}", token))
 851                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 852                    .body(
 853                        serde_json::to_string(&AcceptEditPredictionBody {
 854                            request_id: request_id.0,
 855                        })?
 856                        .into(),
 857                    )?)
 858            })
 859            .await?;
 860
 861            if let Some(minimum_required_version) = response
 862                .headers()
 863                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 864                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 865                && app_version < minimum_required_version
 866            {
 867                return Err(anyhow!(ZedUpdateRequiredError {
 868                    minimum_version: minimum_required_version
 869                }));
 870            }
 871
 872            if response.status().is_success() {
 873                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 874                    this.update(cx, |this, cx| {
 875                        this.user_store.update(cx, |user_store, cx| {
 876                            user_store.update_edit_prediction_usage(usage, cx);
 877                        });
 878                    })?;
 879                }
 880
 881                Ok(())
 882            } else {
 883                let mut body = String::new();
 884                response.body_mut().read_to_string(&mut body).await?;
 885                Err(anyhow!(
 886                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 887                    response.status(),
 888                    body
 889                ))
 890            }
 891        })
 892    }
 893
 894    fn process_completion_response(
 895        prediction_response: PredictEditsResponse,
 896        buffer: Entity<Buffer>,
 897        snapshot: &BufferSnapshot,
 898        editable_range: Range<usize>,
 899        cursor_offset: usize,
 900        path: Arc<Path>,
 901        input_outline: String,
 902        input_events: String,
 903        input_excerpt: String,
 904        buffer_snapshotted_at: Instant,
 905        cx: &AsyncApp,
 906    ) -> Task<Result<Option<EditPrediction>>> {
 907        let snapshot = snapshot.clone();
 908        let request_id = prediction_response.request_id;
 909        let output_excerpt = prediction_response.output_excerpt;
 910        cx.spawn(async move |cx| {
 911            let output_excerpt: Arc<str> = output_excerpt.into();
 912
 913            let edits: Arc<[(Range<Anchor>, String)]> = cx
 914                .background_spawn({
 915                    let output_excerpt = output_excerpt.clone();
 916                    let editable_range = editable_range.clone();
 917                    let snapshot = snapshot.clone();
 918                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
 919                })
 920                .await?
 921                .into();
 922
 923            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
 924                let edits = edits.clone();
 925                |buffer, cx| {
 926                    let new_snapshot = buffer.snapshot();
 927                    let edits: Arc<[(Range<Anchor>, String)]> =
 928                        interpolate(&snapshot, &new_snapshot, edits)?.into();
 929                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 930                }
 931            })?
 932            else {
 933                return anyhow::Ok(None);
 934            };
 935
 936            let edit_preview = edit_preview.await;
 937
 938            Ok(Some(EditPrediction {
 939                id: EditPredictionId(request_id),
 940                path,
 941                excerpt_range: editable_range,
 942                cursor_offset,
 943                edits,
 944                edit_preview,
 945                snapshot,
 946                input_outline: input_outline.into(),
 947                input_events: input_events.into(),
 948                input_excerpt: input_excerpt.into(),
 949                output_excerpt,
 950                buffer_snapshotted_at,
 951                response_received_at: Instant::now(),
 952            }))
 953        })
 954    }
 955
 956    fn parse_edits(
 957        output_excerpt: Arc<str>,
 958        editable_range: Range<usize>,
 959        snapshot: &BufferSnapshot,
 960    ) -> Result<Vec<(Range<Anchor>, String)>> {
 961        let content = output_excerpt.replace(CURSOR_MARKER, "");
 962
 963        let start_markers = content
 964            .match_indices(EDITABLE_REGION_START_MARKER)
 965            .collect::<Vec<_>>();
 966        anyhow::ensure!(
 967            start_markers.len() == 1,
 968            "expected exactly one start marker, found {}",
 969            start_markers.len()
 970        );
 971
 972        let end_markers = content
 973            .match_indices(EDITABLE_REGION_END_MARKER)
 974            .collect::<Vec<_>>();
 975        anyhow::ensure!(
 976            end_markers.len() == 1,
 977            "expected exactly one end marker, found {}",
 978            end_markers.len()
 979        );
 980
 981        let sof_markers = content
 982            .match_indices(START_OF_FILE_MARKER)
 983            .collect::<Vec<_>>();
 984        anyhow::ensure!(
 985            sof_markers.len() <= 1,
 986            "expected at most one start-of-file marker, found {}",
 987            sof_markers.len()
 988        );
 989
 990        let codefence_start = start_markers[0].0;
 991        let content = &content[codefence_start..];
 992
 993        let newline_ix = content.find('\n').context("could not find newline")?;
 994        let content = &content[newline_ix + 1..];
 995
 996        let codefence_end = content
 997            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 998            .context("could not find end marker")?;
 999        let new_text = &content[..codefence_end];
1000
1001        let old_text = snapshot
1002            .text_for_range(editable_range.clone())
1003            .collect::<String>();
1004
1005        Ok(Self::compute_edits(
1006            old_text,
1007            new_text,
1008            editable_range.start,
1009            snapshot,
1010        ))
1011    }
1012
1013    pub fn compute_edits(
1014        old_text: String,
1015        new_text: &str,
1016        offset: usize,
1017        snapshot: &BufferSnapshot,
1018    ) -> Vec<(Range<Anchor>, String)> {
1019        text_diff(&old_text, new_text)
1020            .into_iter()
1021            .map(|(mut old_range, new_text)| {
1022                old_range.start += offset;
1023                old_range.end += offset;
1024
1025                let prefix_len = common_prefix(
1026                    snapshot.chars_for_range(old_range.clone()),
1027                    new_text.chars(),
1028                );
1029                old_range.start += prefix_len;
1030
1031                let suffix_len = common_prefix(
1032                    snapshot.reversed_chars_for_range(old_range.clone()),
1033                    new_text[prefix_len..].chars().rev(),
1034                );
1035                old_range.end = old_range.end.saturating_sub(suffix_len);
1036
1037                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
1038                let range = if old_range.is_empty() {
1039                    let anchor = snapshot.anchor_after(old_range.start);
1040                    anchor..anchor
1041                } else {
1042                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
1043                };
1044                (range, new_text)
1045            })
1046            .collect()
1047    }
1048
    /// Returns `true` if `completion_id` is currently in the set of rated completions.
    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
1052
1053    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1054        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1055            let completion = self.shown_completions.pop_back().unwrap();
1056            self.rated_completions.remove(&completion.id);
1057        }
1058        self.shown_completions.push_front(completion.clone());
1059        cx.notify();
1060    }
1061
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// The "Edit Prediction Rated" event carries the rating, the free-form `feedback` and the
    /// prediction's recorded inputs and output excerpt. A flush of telemetry events is started
    /// right away (detached, not awaited) and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1082
    /// Iterates over the shown predictions from front to back; `completion_shown` pushes new
    /// entries at the front, so the most recently shown prediction comes first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_completions.iter()
    }
1086
    /// Returns how many shown predictions are currently retained.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
1090
1091    fn report_changes_for_buffer(
1092        &mut self,
1093        buffer: &Entity<Buffer>,
1094        cx: &mut Context<Self>,
1095    ) -> BufferSnapshot {
1096        self.register_buffer(buffer, cx);
1097
1098        let registered_buffer = self
1099            .registered_buffers
1100            .get_mut(&buffer.entity_id())
1101            .unwrap();
1102        let new_snapshot = buffer.read(cx).snapshot();
1103
1104        if new_snapshot.version != registered_buffer.snapshot.version {
1105            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1106            self.push_event(Event::BufferChange {
1107                old_snapshot,
1108                new_snapshot: new_snapshot.clone(),
1109                timestamp: Instant::now(),
1110            });
1111        }
1112
1113        new_snapshot
1114    }
1115
1116    fn load_data_collection_choices() -> DataCollectionChoice {
1117        let choice = KEY_VALUE_STORE
1118            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1119            .log_err()
1120            .flatten();
1121
1122        match choice.as_deref() {
1123            Some("true") => DataCollectionChoice::Enabled,
1124            Some("false") => DataCollectionChoice::Disabled,
1125            Some(_) => {
1126                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1127                DataCollectionChoice::NotAnswered
1128            }
1129            None => DataCollectionChoice::NotAnswered,
1130        }
1131    }
1132
1133    fn gather_git_info(
1134        &mut self,
1135        project: Option<&Entity<Project>>,
1136        buffer_snapshotted_at: &Instant,
1137        snapshot: &BufferSnapshot,
1138        cx: &Context<Self>,
1139    ) -> Option<PredictEditsGitInfo> {
1140        let project = project?.read(cx);
1141        let file = snapshot.file()?;
1142        let project_path = ProjectPath::from_file(file.as_ref(), cx);
1143        let entry = project.entry_for_path(&project_path, cx)?;
1144        if !worktree_entry_eligible_for_collection(&entry) {
1145            return None;
1146        }
1147
1148        let git_store = project.git_store().read(cx);
1149        let (repository, repo_path) =
1150            git_store.repository_and_path_for_project_path(&project_path, cx)?;
1151        let repo_path_str = repo_path.to_str()?;
1152
1153        let repository = repository.read(cx);
1154        let head_sha = repository
1155            .head_commit
1156            .as_ref()
1157            .map(|head_commit| head_commit.sha.to_string());
1158        let remote_origin_url = repository.remote_origin_url.clone();
1159        let remote_upstream_url = repository.remote_upstream_url.clone();
1160        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1161            return None;
1162        }
1163
1164        let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1165
1166        Some(PredictEditsGitInfo {
1167            input_path: Some(repo_path_str.to_string()),
1168            head_sha,
1169            remote_origin_url,
1170            remote_upstream_url,
1171            recent_files: Some(recent_files),
1172        })
1173    }
1174
1175    fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1176        let now = Instant::now();
1177        if let Some(existing_ix) = self
1178            .recent_project_entries
1179            .iter()
1180            .rposition(|(id, _)| *id == project_entry_id)
1181        {
1182            self.recent_project_entries.remove(existing_ix);
1183        }
1184        if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1185            self.recent_project_entries.pop_front();
1186        }
1187        self.recent_project_entries
1188            .push_back((project_entry_id, now));
1189    }
1190
    /// Builds the list of recently active files that live inside `repository`, pruning stale
    /// entries from `recent_project_entries` along the way.
    ///
    /// Entries are visited from the back of the queue to the front; since
    /// `push_recent_project_entry` appends at the back, the result starts with the most
    /// recently pushed entry. `now` is the instant each entry's activity time is measured
    /// against. Returns an empty list when the project cannot be read from the workspace.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &Context<Self>,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::new();
        // Indices are walked in reverse so that removing the entry at `ix` only shifts entries
        // at higher indices, all of which have been visited already.
        for ix in (0..self.recent_project_entries.len()).rev() {
            let (entry_id, last_active_at) = &self.recent_project_entries[ix];
            if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
                && let worktree = worktree.read(cx)
                && let Some(entry) = worktree.entry_for_id(*entry_id)
                && worktree_entry_eligible_for_collection(entry)
            {
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: entry.path.clone(),
                };
                let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
                else {
                    // entry not removed since queries involving other repositories might occur later
                    continue;
                };
                let Some(repo_path_str) = repo_path.to_str() else {
                    // paths may not be valid UTF-8
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                // Paths longer than the limit are dropped from the queue instead of reported.
                if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                    self.recent_project_entries.remove(ix);
                    continue;
                }
                // The elapsed milliseconds must fit the integer type of `active_to_now_ms`;
                // entries for which the conversion fails are dropped.
                let Ok(active_to_now_ms) =
                    now.duration_since(*last_active_at).as_millis().try_into()
                else {
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                results.push(PredictEditsRecentFile {
                    path: repo_path_str.to_string(),
                    active_to_now_ms,
                });
            } else {
                // The worktree or entry could not be found, or the entry is not eligible for
                // collection.
                self.recent_project_entries.remove(ix);
            }
        }
        results
    }
1245}
1246
1247fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1248    entry.is_file()
1249        && entry.is_created()
1250        && !entry.is_ignored
1251        && !entry.is_private
1252        && !entry.is_external
1253        && !entry.is_fifo
1254}
1255
/// Bundle of a client, an LLM API token, the app version and a predict-edits request body.
///
/// NOTE(review): presumably the argument of the function that performs the predict-edits
/// request; its consumer is not visible in this part of the file — confirm.
pub struct PerformPredictEditsParams {
    pub client: Arc<Client>,
    pub llm_token: LlmApiToken,
    pub app_version: SemanticVersion,
    pub body: PredictEditsBody,
}
1262
/// Error raised when a response names a minimum required Zed version that is newer than the
/// running app (see `accept_edit_prediction`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version taken from the response header.
    minimum_version: SemanticVersion,
}
1270
/// Returns the length in UTF-8 bytes of the longest run of leading characters on which `a` and
/// `b` agree. Comparison stops at the first mismatch or when either iterator runs out.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1277
/// Result of `gather_context`.
pub struct GatherContextOutput {
    /// The request body assembled for the prediction request.
    pub body: PredictEditsBody,
    /// The editable range of the input excerpt, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
}
1282
1283pub fn gather_context(
1284    project: Option<&Entity<Project>>,
1285    full_path_str: String,
1286    snapshot: &BufferSnapshot,
1287    cursor_point: language::Point,
1288    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1289    can_collect_data: CanCollectData,
1290    git_info: Option<PredictEditsGitInfo>,
1291    cx: &App,
1292) -> Task<Result<GatherContextOutput>> {
1293    let local_lsp_store =
1294        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1295    let diagnostic_groups: Vec<(String, serde_json::Value)> =
1296        if matches!(can_collect_data, CanCollectData(true))
1297            && let Some(local_lsp_store) = local_lsp_store
1298        {
1299            snapshot
1300                .diagnostic_groups(None)
1301                .into_iter()
1302                .filter_map(|(language_server_id, diagnostic_group)| {
1303                    let language_server =
1304                        local_lsp_store.running_language_server_for_id(language_server_id)?;
1305                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1306                    let language_server_name = language_server.name().to_string();
1307                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
1308                    Some((language_server_name, serialized))
1309                })
1310                .collect::<Vec<_>>()
1311        } else {
1312            Vec::new()
1313        };
1314
1315    cx.background_spawn({
1316        let snapshot = snapshot.clone();
1317        async move {
1318            let diagnostic_groups = if diagnostic_groups.is_empty()
1319                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1320            {
1321                None
1322            } else {
1323                Some(diagnostic_groups)
1324            };
1325
1326            let input_excerpt = excerpt_for_cursor_position(
1327                cursor_point,
1328                &full_path_str,
1329                &snapshot,
1330                MAX_REWRITE_TOKENS,
1331                MAX_CONTEXT_TOKENS,
1332            );
1333            let input_events = make_events_prompt();
1334            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1335
1336            let body = PredictEditsBody {
1337                input_events,
1338                input_excerpt: input_excerpt.prompt,
1339                can_collect_data: can_collect_data.0,
1340                diagnostic_groups,
1341                git_info,
1342                outline: None,
1343                speculated_output: None,
1344            };
1345
1346            Ok(GatherContextOutput {
1347                body,
1348                editable_range,
1349            })
1350        }
1351    })
1352}
1353
1354fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1355    let mut result = String::new();
1356    for event in events.iter().rev() {
1357        let event_string = event.to_prompt();
1358        let event_tokens = tokens_for_bytes(event_string.len());
1359        if event_tokens > remaining_tokens {
1360            break;
1361        }
1362
1363        if !result.is_empty() {
1364            result.insert_str(0, "\n\n");
1365        }
1366        result.insert_str(0, &event_string);
1367        remaining_tokens -= event_tokens;
1368    }
1369    result
1370}
1371
/// Bookkeeping kept for every buffer registered with `Zeta`.
struct RegisteredBuffer {
    /// Snapshot recorded the last time the buffer was checked for changes;
    /// `report_changes_for_buffer` replaces it whenever the buffer's version has moved on.
    snapshot: BufferSnapshot,
    // NOTE(review): never read in the visible code; presumably held only to keep the two
    // subscriptions alive for as long as the buffer stays registered — confirm.
    _subscriptions: [gpui::Subscription; 2],
}
1376
/// An event that can be rendered into the prompt via `to_prompt`.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents (and possibly its path) changed between two snapshots.
    BufferChange {
        /// The buffer before the change.
        old_snapshot: BufferSnapshot,
        /// The buffer after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1385
1386impl Event {
1387    fn to_prompt(&self) -> String {
1388        match self {
1389            Event::BufferChange {
1390                old_snapshot,
1391                new_snapshot,
1392                ..
1393            } => {
1394                let mut prompt = String::new();
1395
1396                let old_path = old_snapshot
1397                    .file()
1398                    .map(|f| f.path().as_ref())
1399                    .unwrap_or(Path::new("untitled"));
1400                let new_path = new_snapshot
1401                    .file()
1402                    .map(|f| f.path().as_ref())
1403                    .unwrap_or(Path::new("untitled"));
1404                if old_path != new_path {
1405                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1406                }
1407
1408                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1409                if !diff.is_empty() {
1410                    write!(
1411                        prompt,
1412                        "User edited {:?}:\n```diff\n{}\n```",
1413                        new_path, diff
1414                    )
1415                    .unwrap();
1416                }
1417
1418                prompt
1419            }
1420        }
1421    }
1422}
1423
/// A prediction together with the id of the buffer entity it is associated with.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer; compared in `should_replace_completion`.
    buffer_id: EntityId,
    /// The prediction itself.
    completion: EditPrediction,
}
1429
1430impl CurrentEditPrediction {
1431    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1432        if self.buffer_id != old_completion.buffer_id {
1433            return true;
1434        }
1435
1436        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1437            return true;
1438        };
1439        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1440            return false;
1441        };
1442
1443        if old_edits.len() == 1 && new_edits.len() == 1 {
1444            let (old_range, old_text) = &old_edits[0];
1445            let (new_range, new_text) = &new_edits[0];
1446            new_range == old_range && new_text.starts_with(old_text)
1447        } else {
1448            true
1449        }
1450    }
1451}
1452
/// A pending completion, as tracked in `ZetaEditPredictionProvider::pending_completions`.
struct PendingCompletion {
    /// Identifier of the pending completion.
    id: usize,
    // NOTE(review): never read in the visible code; presumably held so the task stays alive
    // while the completion is pending — confirm.
    _task: Task<()>,
}
1457
/// The user's answer to whether edit prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection is enabled.
    Enabled,
    /// Data collection is disabled.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for `Enabled`; an unanswered choice counts as not enabled.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the choice is either `Enabled` or `Disabled`.
    pub fn is_answered(self) -> bool {
        matches!(self, Self::Enabled | Self::Disabled)
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::NotAnswered | Self::Disabled => Self::Enabled,
            Self::Enabled => Self::Disabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1497
/// Data-collection state for one provider: the shared user choice plus the license detection
/// watcher of the buffer's worktree. Both fields are `None` when `new` finds that collection is
/// not possible for the buffer.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher registered for the worktree of the buffer's file; asked whether the project is
    /// open source.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1503
/// Whether data may be collected. `ProviderDataCollection::can_collect_data` sets it to `true`
/// only when collection is enabled and the project is open source.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1506
1507impl ProviderDataCollection {
1508    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1509        let choice_and_watcher = buffer.and_then(|buffer| {
1510            let file = buffer.read(cx).file()?;
1511
1512            if !file.is_local() || file.is_private() {
1513                return None;
1514            }
1515
1516            let zeta = zeta.read(cx);
1517            let choice = zeta.data_collection_choice.clone();
1518
1519            let license_detection_watcher = zeta
1520                .license_detection_watchers
1521                .get(&file.worktree_id(cx))
1522                .cloned()?;
1523
1524            Some((choice, license_detection_watcher))
1525        });
1526
1527        if let Some((choice, watcher)) = choice_and_watcher {
1528            ProviderDataCollection {
1529                choice: Some(choice),
1530                license_detection_watcher: Some(watcher),
1531            }
1532        } else {
1533            ProviderDataCollection {
1534                choice: None,
1535                license_detection_watcher: None,
1536            }
1537        }
1538    }
1539
1540    pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1541        CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1542    }
1543
1544    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1545        self.choice
1546            .as_ref()
1547            .is_some_and(|choice| choice.read(cx).is_enabled())
1548    }
1549
1550    fn is_project_open_source(&self) -> bool {
1551        self.license_detection_watcher
1552            .as_ref()
1553            .is_some_and(|watcher| watcher.is_project_open_source())
1554    }
1555
1556    pub fn toggle(&mut self, cx: &mut App) {
1557        if let Some(choice) = self.choice.as_mut() {
1558            let new_choice = choice.update(cx, |choice, _cx| {
1559                let new_choice = choice.toggle();
1560                *choice = new_choice;
1561                new_choice
1562            });
1563
1564            db::write_and_log(cx, move || {
1565                KEY_VALUE_STORE.write_kvp(
1566                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1567                    new_choice.is_enabled().to_string(),
1568                )
1569            });
1570        }
1571    }
1572}
1573
1574async fn llm_token_retry(
1575    llm_token: &LlmApiToken,
1576    client: &Arc<Client>,
1577    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1578) -> Result<Response<AsyncBody>> {
1579    let mut did_retry = false;
1580    let http_client = client.http_client();
1581    let mut token = llm_token.acquire(client).await?;
1582    loop {
1583        let request = build_request(token.clone())?;
1584        let response = http_client.send(request).await?;
1585
1586        if !did_retry
1587            && !response.status().is_success()
1588            && response
1589                .headers()
1590                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1591                .is_some()
1592        {
1593            did_retry = true;
1594            token = llm_token.refresh(client).await?;
1595            continue;
1596        }
1597
1598        return Ok(response);
1599    }
1600}
1601
/// Edit prediction provider backed by a `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    /// The shared `Zeta` entity that completion requests are issued through.
    zeta: Entity<Zeta>,
    /// Pending completions; the fixed-capacity vector holds at most two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next pending completion; `refresh` increments it for each request.
    next_pending_completion_id: usize,
    /// The current prediction, if there is one.
    current_completion: Option<CurrentEditPrediction>,
    /// Data-collection state for this provider. The value is always present; its inner
    /// `choice` is `None` when data collection is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the last completion request was issued; `refresh` throttles against it.
    last_request_timestamp: Instant,
}
1611
impl ZetaEditPredictionProvider {
    /// Throttle interval: `refresh` waits until this much time has passed since
    /// `last_request_timestamp` before it requests a completion.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current completion. The last-request timestamp
    /// starts at the time of construction.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1626
1627impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
    /// Returns the provider's identifier string, `zed-predict`.
    fn name() -> &'static str {
        "zed-predict"
    }
1631
    /// Returns the provider's human-readable name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }
1635
    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
1639
    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
1643
1644    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1645        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1646
1647        if self.provider_data_collection.is_data_collection_enabled(cx) {
1648            DataCollectionState::Enabled {
1649                is_project_open_source,
1650            }
1651        } else {
1652            DataCollectionState::Disabled {
1653                is_project_open_source,
1654            }
1655        }
1656    }
1657
    /// Toggles the data-collection choice by delegating to `ProviderDataCollection::toggle`.
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }
1661
    /// Returns the edit prediction usage reported by the underlying `Zeta` entity.
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
1665
    /// Always returns `true`; the buffer, cursor position and context are ignored.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
    /// Returns `true` while at least one completion is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }
1677
1678    fn refresh(
1679        &mut self,
1680        project: Option<Entity<Project>>,
1681        buffer: Entity<Buffer>,
1682        position: language::Anchor,
1683        _debounce: bool,
1684        cx: &mut Context<Self>,
1685    ) {
1686        if self.zeta.read(cx).update_required {
1687            return;
1688        }
1689
1690        if self
1691            .zeta
1692            .read(cx)
1693            .user_store
1694            .read_with(cx, |user_store, _cx| {
1695                user_store.account_too_young() || user_store.has_overdue_invoices()
1696            })
1697        {
1698            return;
1699        }
1700
1701        if let Some(current_completion) = self.current_completion.as_ref() {
1702            let snapshot = buffer.read(cx).snapshot();
1703            if current_completion
1704                .completion
1705                .interpolate(&snapshot)
1706                .is_some()
1707            {
1708                return;
1709            }
1710        }
1711
1712        let pending_completion_id = self.next_pending_completion_id;
1713        self.next_pending_completion_id += 1;
1714        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1715        let last_request_timestamp = self.last_request_timestamp;
1716
1717        let task = cx.spawn(async move |this, cx| {
1718            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1719                .checked_duration_since(Instant::now())
1720            {
1721                cx.background_executor().timer(timeout).await;
1722            }
1723
1724            let completion_request = this.update(cx, |this, cx| {
1725                this.last_request_timestamp = Instant::now();
1726                this.zeta.update(cx, |zeta, cx| {
1727                    zeta.request_completion(
1728                        project.as_ref(),
1729                        &buffer,
1730                        position,
1731                        can_collect_data,
1732                        cx,
1733                    )
1734                })
1735            });
1736
1737            let completion = match completion_request {
1738                Ok(completion_request) => {
1739                    let completion_request = completion_request.await;
1740                    completion_request.map(|c| {
1741                        c.map(|completion| CurrentEditPrediction {
1742                            buffer_id: buffer.entity_id(),
1743                            completion,
1744                        })
1745                    })
1746                }
1747                Err(error) => Err(error),
1748            };
1749            let Some(new_completion) = completion
1750                .context("edit prediction failed")
1751                .log_err()
1752                .flatten()
1753            else {
1754                this.update(cx, |this, cx| {
1755                    if this.pending_completions[0].id == pending_completion_id {
1756                        this.pending_completions.remove(0);
1757                    } else {
1758                        this.pending_completions.clear();
1759                    }
1760
1761                    cx.notify();
1762                })
1763                .ok();
1764                return;
1765            };
1766
1767            this.update(cx, |this, cx| {
1768                if this.pending_completions[0].id == pending_completion_id {
1769                    this.pending_completions.remove(0);
1770                } else {
1771                    this.pending_completions.clear();
1772                }
1773
1774                if let Some(old_completion) = this.current_completion.as_ref() {
1775                    let snapshot = buffer.read(cx).snapshot();
1776                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1777                        this.zeta.update(cx, |zeta, cx| {
1778                            zeta.completion_shown(&new_completion.completion, cx);
1779                        });
1780                        this.current_completion = Some(new_completion);
1781                    }
1782                } else {
1783                    this.zeta.update(cx, |zeta, cx| {
1784                        zeta.completion_shown(&new_completion.completion, cx);
1785                    });
1786                    this.current_completion = Some(new_completion);
1787                }
1788
1789                cx.notify();
1790            })
1791            .ok();
1792        });
1793
1794        // We always maintain at most two pending completions. When we already
1795        // have two, we replace the newest one.
1796        if self.pending_completions.len() <= 1 {
1797            self.pending_completions.push(PendingCompletion {
1798                id: pending_completion_id,
1799                _task: task,
1800            });
1801        } else if self.pending_completions.len() == 2 {
1802            self.pending_completions.pop();
1803            self.pending_completions.push(PendingCompletion {
1804                id: pending_completion_id,
1805                _task: task,
1806            });
1807        }
1808    }
1809
    /// No-op: this provider does not offer alternative predictions to cycle
    /// through, so every argument is ignored.
    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: edit_prediction::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }
1819
1820    fn accept(&mut self, cx: &mut Context<Self>) {
1821        let completion_id = self
1822            .current_completion
1823            .as_ref()
1824            .map(|completion| completion.completion.id);
1825        if let Some(completion_id) = completion_id {
1826            self.zeta
1827                .update(cx, |zeta, cx| {
1828                    zeta.accept_edit_prediction(completion_id, cx)
1829                })
1830                .detach();
1831        }
1832        self.pending_completions.clear();
1833    }
1834
1835    fn discard(&mut self, _cx: &mut Context<Self>) {
1836        self.pending_completions.clear();
1837        self.current_completion.take();
1838    }
1839
1840    fn suggest(
1841        &mut self,
1842        buffer: &Entity<Buffer>,
1843        cursor_position: language::Anchor,
1844        cx: &mut Context<Self>,
1845    ) -> Option<edit_prediction::EditPrediction> {
1846        let CurrentEditPrediction {
1847            buffer_id,
1848            completion,
1849            ..
1850        } = self.current_completion.as_mut()?;
1851
1852        // Invalidate previous completion if it was generated for a different buffer.
1853        if *buffer_id != buffer.entity_id() {
1854            self.current_completion.take();
1855            return None;
1856        }
1857
1858        let buffer = buffer.read(cx);
1859        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1860            self.current_completion.take();
1861            return None;
1862        };
1863
1864        let cursor_row = cursor_position.to_point(buffer).row;
1865        let (closest_edit_ix, (closest_edit_range, _)) =
1866            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1867                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1868                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1869                cmp::min(distance_from_start, distance_from_end)
1870            })?;
1871
1872        let mut edit_start_ix = closest_edit_ix;
1873        for (range, _) in edits[..edit_start_ix].iter().rev() {
1874            let distance_from_closest_edit =
1875                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1876            if distance_from_closest_edit <= 1 {
1877                edit_start_ix -= 1;
1878            } else {
1879                break;
1880            }
1881        }
1882
1883        let mut edit_end_ix = closest_edit_ix + 1;
1884        for (range, _) in &edits[edit_end_ix..] {
1885            let distance_from_closest_edit =
1886                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1887            if distance_from_closest_edit <= 1 {
1888                edit_end_ix += 1;
1889            } else {
1890                break;
1891            }
1892        }
1893
1894        Some(edit_prediction::EditPrediction {
1895            id: Some(completion.id.to_string().into()),
1896            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1897            edit_preview: Some(completion.edit_preview.clone()),
1898        })
1899    }
1900}
1901
/// Estimates how many tokens `bytes` bytes of text correspond to, rounding
/// down.
fn tokens_for_bytes(bytes: usize) -> usize {
    // Typical number of string bytes per token, used when limiting model
    // input. The guess is intentionally low to err on the side of
    // underestimating limits.
    const ASSUMED_BYTES_PER_TOKEN: usize = 3;

    let estimated_tokens = bytes / ASSUMED_BYTES_PER_TOKEN;
    estimated_tokens
}
1908
1909#[cfg(test)]
1910mod tests {
1911    use client::UserStore;
1912    use client::test::FakeServer;
1913    use clock::FakeSystemClock;
1914    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
1915    use gpui::TestAppContext;
1916    use http_client::FakeHttpClient;
1917    use indoc::indoc;
1918    use language::Point;
1919    use settings::SettingsStore;
1920
1921    use super::*;
1922
    // Checks that `EditPrediction::interpolate` re-bases a prediction's edits
    // as the buffer is edited, and yields `None` once the buffer diverges from
    // the prediction.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // The prediction replaces "rem" (2..5) with "REM" and deletes "um" (9..11).
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        // Only `edits` and `snapshot` carry meaningful data here; the remaining
        // fields are filled with placeholder values.
        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the edits come back unchanged.
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting "rem" turns the replacement into a pure insertion and
            // shifts the second edit left by three.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing the deletion restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Replacing "rem" with "R" (a prefix of the predicted text) leaves
            // only the remaining "EM" to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // Typing the next predicted character, "E", leaves "M".
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Typing "M" completes the first edit; only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the typed "M" again brings the "M" insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion of "um" by hand removes that
            // edit from the result.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // Deleting text the prediction did not touch makes interpolation
            // fail.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }
2039
    // Checks that a predicted excerpt is reduced to small, targeted insertions
    // rather than whole-line replacements.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        // Case 1: the response renames both uses of `word` to `word_1`; the
        // expected edits are two insertions of "_1" on row 2.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        // Case 2: the response extends the string literal and appends a
        // semicolon; the expected edits are two insertions on row 1, one on each
        // side of the existing closing quote.
        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }
2103
    // Checks that a prediction which appends a line at the very end of the
    // buffer applies cleanly: "lorem\n" becomes "lorem\nipsum".
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Fake HTTP backend: hands out an LLM token and answers prediction
        // requests with the canned response above; anything else is a 404.
        let http_client = FakeHttpClient::create(move |req| async move {
            match (req.method(), req.uri().path()) {
                (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&CreateLlmTokenResponse {
                            token: LlmToken("the-llm-token".to_string()),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
                                .unwrap(),
                            output_excerpt: completion_response.to_string(),
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap()),
                _ => Ok(http_client::Response::builder()
                    .status(404)
                    .body("Not Found".into())
                    .unwrap()),
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Place the cursor at the start of row 1, i.e. right after the trailing
        // newline at the end of the buffer.
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        // Apply the predicted edits and compare the resulting text.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }
2177
    // Test helper: requests a completion for a buffer containing
    // `buffer_content` (cursor at row 1, column 0) from a `Zeta` whose fake HTTP
    // backend answers with `completion_response`, and returns the resulting
    // edits as point ranges resolved against the snapshot taken before the
    // request.
    //
    // Panics if the request fails or yields no completion.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let completion_response = completion_response.to_string();
        // Fake HTTP backend: hands out an LLM token and answers prediction
        // requests with the canned response; anything else is a 404.
        let http_client = FakeHttpClient::create(move |req| {
            let completion = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id: Uuid::new_v4(),
                                output_excerpt: completion,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let _server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        // Snapshot taken up front so the returned ranges refer to the
        // unmodified buffer contents.
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }
2240
2241    fn to_completion_edits(
2242        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
2243        buffer: &Entity<Buffer>,
2244        cx: &App,
2245    ) -> Vec<(Range<Anchor>, String)> {
2246        let buffer = buffer.read(cx);
2247        iterator
2248            .into_iter()
2249            .map(|(range, text)| {
2250                (
2251                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2252                    text,
2253                )
2254            })
2255            .collect()
2256    }
2257
2258    fn from_completion_edits(
2259        editor_edits: &[(Range<Anchor>, String)],
2260        buffer: &Entity<Buffer>,
2261        cx: &App,
2262    ) -> Vec<(Range<usize>, String)> {
2263        let buffer = buffer.read(cx);
2264        editor_edits
2265            .iter()
2266            .map(|(range, text)| {
2267                (
2268                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
2269                    text.clone(),
2270                )
2271            })
2272            .collect()
2273    }
2274
    // Registered through `ctor`, so it runs when the test binary is loaded,
    // before any test executes; it calls `zlog::init_test()`.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
2279}