zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9use arrayvec::ArrayVec;
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction::DataCollectionState;
  13pub use init::*;
  14use license_detection::LicenseDetectionWatcher;
  15use project::git_store::Repository;
  16pub use rate_completion_modal::*;
  17
  18use anyhow::{Context as _, Result, anyhow};
  19use client::{Client, EditPredictionUsage, UserStore};
  20use cloud_llm_client::{
  21    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  22    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
  23    ZED_VERSION_HEADER_NAME,
  24};
  25use collections::{HashMap, HashSet, VecDeque};
  26use futures::AsyncReadExt;
  27use gpui::{
  28    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  29    Subscription, Task, WeakEntity, actions,
  30};
  31use http_client::{AsyncBody, HttpClient, Method, Request, Response};
  32use input_excerpt::excerpt_for_cursor_position;
  33use language::{
  34    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  35};
  36use language_model::{LlmApiToken, RefreshLlmTokenListener};
  37use project::{Project, ProjectEntryId, ProjectPath};
  38use release_channel::AppVersion;
  39use settings::WorktreeId;
  40use std::str::FromStr;
  41use std::{
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    mem,
  46    ops::Range,
  47    path::Path,
  48    rc::Rc,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::EditPredictionRating;
  53use thiserror::Error;
  54use util::ResultExt;
  55use uuid::Uuid;
  56use workspace::Workspace;
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  58use worktree::Worktree;
  59
  60const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
  61const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
  62const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
  63const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
  64const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  65const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  66
  67const MAX_CONTEXT_TOKENS: usize = 150;
  68const MAX_REWRITE_TOKENS: usize = 350;
  69const MAX_EVENT_TOKENS: usize = 500;
  70const MAX_DIAGNOSTIC_GROUPS: usize = 10;
  71
  72/// Maximum number of events to track.
  73const MAX_EVENT_COUNT: usize = 16;
  74
  75/// Maximum number of recent files to track.
  76const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
  77
  78/// Maximum file path length to include in recent files list.
  79const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;
  80
  81/// Maximum number of edit predictions to store for feedback.
  82const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
  83
// Declares the `edit_prediction::ClearHistory` action.
// NOTE(review): presumably bound to `Zeta::clear_history` elsewhere (e.g. init.rs) — confirm.
actions!(
    edit_prediction,
    [
        /// Clears the edit prediction history.
        ClearHistory
    ]
);
  91
  92#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  93pub struct EditPredictionId(Uuid);
  94
  95impl From<EditPredictionId> for gpui::ElementId {
  96    fn from(value: EditPredictionId) -> Self {
  97        gpui::ElementId::Uuid(value.0)
  98    }
  99}
 100
 101impl std::fmt::Display for EditPredictionId {
 102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 103        write!(f, "{}", self.0)
 104    }
 105}
 106
 107struct ZedPredictUpsell;
 108
 109impl Dismissable for ZedPredictUpsell {
 110    const KEY: &'static str = "dismissed-edit-predict-upsell";
 111
 112    fn dismissed() -> bool {
 113        // To make this backwards compatible with older versions of Zed, we
 114        // check if the user has seen the previous Edit Prediction Onboarding
 115        // before, by checking the data collection choice which was written to
 116        // the database once the user clicked on "Accept and Enable"
 117        if KEY_VALUE_STORE
 118            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 119            .log_err()
 120            .is_some_and(|s| s.is_some())
 121        {
 122            return true;
 123        }
 124
 125        KEY_VALUE_STORE
 126            .read_kvp(Self::KEY)
 127            .log_err()
 128            .is_some_and(|s| s.is_some())
 129    }
 130}
 131
 132pub fn should_show_upsell_modal() -> bool {
 133    !ZedPredictUpsell::dismissed()
 134}
 135
/// App-wide global holding the shared `Zeta` entity.
///
/// `Zeta::register` installs it on first use, and `Zeta::global` reads it back.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 140
 141#[derive(Clone)]
 142pub struct EditPrediction {
 143    id: EditPredictionId,
 144    path: Arc<Path>,
 145    excerpt_range: Range<usize>,
 146    cursor_offset: usize,
 147    edits: Arc<[(Range<Anchor>, String)]>,
 148    snapshot: BufferSnapshot,
 149    edit_preview: EditPreview,
 150    input_outline: Arc<str>,
 151    input_events: Arc<str>,
 152    input_excerpt: Arc<str>,
 153    output_excerpt: Arc<str>,
 154    buffer_snapshotted_at: Instant,
 155    response_received_at: Instant,
 156}
 157
 158impl EditPrediction {
 159    fn latency(&self) -> Duration {
 160        self.response_received_at
 161            .duration_since(self.buffer_snapshotted_at)
 162    }
 163
 164    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 165        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 166    }
 167}
 168
/// Re-bases `current_edits`, computed against `old_snapshot`, onto `new_snapshot`.
///
/// Walks the user's edits made since `old_snapshot` in order, alongside the model's edits.
/// NOTE(review): the merge assumes `current_edits` is sorted by position — confirm.
///
/// - Model edits that end strictly before the current user edit starts are kept unchanged.
/// - A user edit that replaced exactly the same old range as the next model edit, with text
///   that is a prefix of the model's replacement, consumes that model edit. The part of the
///   model text the user has not typed yet is kept as an insertion anchored just after the
///   end of the replaced range. Nothing is kept if the user typed all of it.
/// - Any other user edit, including one after all model edits were consumed, invalidates the
///   prediction, and `None` is returned.
///
/// Model edits remaining after the last user edit are kept unchanged. `None` is also
/// returned when no edits remain at all.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    // Model edits are consumed in order while walking the user's edits.
    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep, unchanged, every model edit that ends strictly before this user edit starts.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The only tolerated overlap: the user replaced exactly the next model edit's range
        // with a prefix of the model's text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Keep only the part of the prediction not typed yet.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // Any other user edit conflicts with the prediction.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
 214
 215impl std::fmt::Debug for EditPrediction {
 216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 217        f.debug_struct("EditPrediction")
 218            .field("id", &self.id)
 219            .field("path", &self.path)
 220            .field("edits", &self.edits)
 221            .finish_non_exhaustive()
 222    }
 223}
 224
/// Edit prediction ("Zeta") state shared across the app.
///
/// Holds the recorded edit history, the tracked buffers, shown and rated predictions, and
/// the credentials used to call the prediction service.
pub struct Zeta {
    /// Workspace used to show notifications. This is an invalid weak handle when `Zeta` was
    /// created without a workspace.
    workspace: WeakEntity<Workspace>,
    /// Client used for HTTP requests and for acquiring/refreshing the LLM token.
    client: Arc<Client>,
    /// Recent buffer-change events, oldest first.
    ///
    /// Back-to-back changes to the same buffer are coalesced, and the oldest half is
    /// dropped once `MAX_EVENT_COUNT` is reached.
    events: VecDeque<Event>,
    /// Buffers currently tracked, keyed by entity id. An entry is removed when its buffer is
    /// released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Predictions that have been shown.
    /// NOTE(review): initial capacity is `MAX_SHOWN_COMPLETION_COUNT`; presumably trimmed to
    /// that size where predictions are pushed — confirm.
    shown_completions: VecDeque<EditPrediction>,
    /// Ids of predictions that were rated.
    /// NOTE(review): populated outside this view.
    rated_completions: HashSet<EditPredictionId>,
    /// Data collection choice, initialized from `load_data_collection_choices`.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM API token. It is refreshed in the background whenever `RefreshLlmTokenListener`
    /// emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    /// Store used to read and update edit-prediction usage.
    user_store: Entity<UserStore>,
    /// One license-detection watcher per worktree, created on demand by `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recently active project entries with an associated instant.
    /// NOTE(review): filled by `push_recent_project_entry` (outside this view); presumably
    /// capped at `MAX_RECENT_PROJECT_ENTRIES_COUNT`.
    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
 241
 242impl Zeta {
 243    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 244        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 245    }
 246
 247    pub fn register(
 248        workspace: Option<Entity<Workspace>>,
 249        worktree: Option<Entity<Worktree>>,
 250        client: Arc<Client>,
 251        user_store: Entity<UserStore>,
 252        cx: &mut App,
 253    ) -> Entity<Self> {
 254        let this = Self::global(cx).unwrap_or_else(|| {
 255            let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
 256            cx.set_global(ZetaGlobal(entity.clone()));
 257            entity
 258        });
 259
 260        this.update(cx, move |this, cx| {
 261            if let Some(worktree) = worktree {
 262                let worktree_id = worktree.read(cx).id();
 263                this.license_detection_watchers
 264                    .entry(worktree_id)
 265                    .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
 266            }
 267        });
 268
 269        this
 270    }
 271
 272    pub fn clear_history(&mut self) {
 273        self.events.clear();
 274    }
 275
 276    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 277        self.user_store.read(cx).edit_prediction_usage()
 278    }
 279
 280    fn new(
 281        workspace: Option<Entity<Workspace>>,
 282        client: Arc<Client>,
 283        user_store: Entity<UserStore>,
 284        cx: &mut Context<Self>,
 285    ) -> Self {
 286        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 287
 288        let data_collection_choice = Self::load_data_collection_choices();
 289        let data_collection_choice = cx.new(|_| data_collection_choice);
 290
 291        if let Some(workspace) = &workspace {
 292            cx.subscribe(
 293                &workspace.read(cx).project().clone(),
 294                |this, _workspace, event, _cx| match event {
 295                    project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
 296                        this.push_recent_project_entry(*project_entry_id)
 297                    }
 298                    _ => {}
 299                },
 300            )
 301            .detach();
 302        }
 303
 304        Self {
 305            workspace: workspace.map_or_else(
 306                || WeakEntity::new_invalid(),
 307                |workspace| workspace.downgrade(),
 308            ),
 309            client,
 310            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
 311            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
 312            rated_completions: HashSet::default(),
 313            registered_buffers: HashMap::default(),
 314            data_collection_choice,
 315            llm_token: LlmApiToken::default(),
 316            _llm_token_subscription: cx.subscribe(
 317                &refresh_llm_token_listener,
 318                |this, _listener, _event, cx| {
 319                    let client = this.client.clone();
 320                    let llm_token = this.llm_token.clone();
 321                    cx.spawn(async move |_this, _cx| {
 322                        llm_token.refresh(&client).await?;
 323                        anyhow::Ok(())
 324                    })
 325                    .detach_and_log_err(cx);
 326                },
 327            ),
 328            update_required: false,
 329            license_detection_watchers: HashMap::default(),
 330            user_store,
 331            recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
 332        }
 333    }
 334
 335    fn push_event(&mut self, event: Event) {
 336        if let Some(Event::BufferChange {
 337            new_snapshot: last_new_snapshot,
 338            timestamp: last_timestamp,
 339            ..
 340        }) = self.events.back_mut()
 341        {
 342            // Coalesce edits for the same buffer when they happen one after the other.
 343            let Event::BufferChange {
 344                old_snapshot,
 345                new_snapshot,
 346                timestamp,
 347            } = &event;
 348
 349            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 350                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 351                && old_snapshot.version == last_new_snapshot.version
 352            {
 353                *last_new_snapshot = new_snapshot.clone();
 354                *last_timestamp = *timestamp;
 355                return;
 356            }
 357        }
 358
 359        if self.events.len() >= MAX_EVENT_COUNT {
 360            // These are halved instead of popping to improve prompt caching.
 361            self.events.drain(..MAX_EVENT_COUNT / 2);
 362        }
 363
 364        self.events.push_back(event);
 365    }
 366
 367    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 368        let buffer_id = buffer.entity_id();
 369        let weak_buffer = buffer.downgrade();
 370
 371        if let std::collections::hash_map::Entry::Vacant(entry) =
 372            self.registered_buffers.entry(buffer_id)
 373        {
 374            let snapshot = buffer.read(cx).snapshot();
 375
 376            entry.insert(RegisteredBuffer {
 377                snapshot,
 378                _subscriptions: [
 379                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 380                        this.handle_buffer_event(buffer, event, cx);
 381                    }),
 382                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 383                        this.registered_buffers.remove(&weak_buffer.entity_id());
 384                    }),
 385                ],
 386            });
 387        };
 388    }
 389
 390    fn handle_buffer_event(
 391        &mut self,
 392        buffer: Entity<Buffer>,
 393        event: &language::BufferEvent,
 394        cx: &mut Context<Self>,
 395    ) {
 396        if let language::BufferEvent::Edited = event {
 397            self.report_changes_for_buffer(&buffer, cx);
 398        }
 399    }
 400
    /// Shared implementation behind `request_completion` and the test-only `fake_completion`.
    ///
    /// It gathers prompt context for `cursor` in `buffer`, sends it through
    /// `perform_predict_edits`, and turns the response into an `EditPrediction`.
    ///
    /// Before the returned task runs, it synchronously:
    /// - reports pending changes for the buffer;
    /// - clones the recorded events;
    /// - gathers git info, but only when `can_collect_data` is `CanCollectData(true)`.
    ///
    /// `perform_predict_edits` is injected so callers can substitute a canned response for
    /// the real HTTP request.
    ///
    /// Error and side-effect behavior:
    /// - If the request fails with `ZedUpdateRequiredError`, `update_required` is set and,
    ///   when a workspace was given, an "Update Zed" notification is shown. The error is
    ///   returned either way.
    /// - Usage data from a successful response is written to the user store.
    /// - About 1% of requests (3 in 256) emit an "Edit Prediction Request" telemetry event
    ///   with context, request, and processing latencies.
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: CanCollectData,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<EditPrediction>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>>
            + Send
            + 'static,
    {
        let buffer = buffer.clone();
        // Start of the latency measurement (see `EditPrediction::latency` and telemetry below).
        let buffer_snapshotted_at = Instant::now();
        // Flush pending buffer changes into the event history and get the snapshot to use.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let zeta = cx.entity();
        let events = self.events.clone();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        // Git information is only gathered when data collection is allowed.
        let git_info = if matches!(can_collect_data, CanCollectData(true)) {
            self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx)
        } else {
            None
        };

        // Buffers without a file are reported as "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
        let full_path_str = full_path.to_string_lossy().to_string();
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // The events prompt is rendered lazily by `gather_context`, capped at MAX_EVENT_TOKENS.
        let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
        let gather_task = gather_context(
            project,
            full_path_str,
            &snapshot,
            cursor_point,
            make_events_prompt,
            can_collect_data,
            git_info,
            cx,
        );

        cx.spawn(async move |this, cx| {
            let GatherContextOutput {
                body,
                editable_range,
            } = gather_task.await?;
            let done_gathering_context_at = Instant::now();

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                body.input_events,
                body.input_excerpt
            );

            // Keep copies of the prompt inputs; `body` itself is moved into the request.
            let input_outline = body.outline.clone().unwrap_or_default();
            let input_events = body.input_events.clone();
            let input_excerpt = body.input_excerpt.clone();

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let (response, usage) = match response {
                Ok(response) => response,
                Err(err) => {
                    // The server requires a newer Zed: remember that and, if possible, tell
                    // the user with a link to the releases page.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|cx| {
                                                ErrorMessagePrompt::new(err.to_string(), cx)
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        // NOTE(review): a failed `cx.update` is deliberately ignored
                        // (presumably the app is shutting down).
                        .ok();
                    }

                    return Err(err);
                }
            };

            let received_response_at = Instant::now();
            log::debug!("completion response: {}", &response.output_excerpt);

            // Store usage reported by the server; ignored if `Zeta` has been dropped.
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }

            let edit_prediction = Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                editable_range,
                cursor_offset,
                full_path,
                input_outline,
                input_events,
                input_excerpt,
                buffer_snapshotted_at,
                cx,
            )
            .await;

            let finished_at = Instant::now();

            // record latency for ~1% of requests
            // (a random u8 is <= 2 for 3 of its 256 values, i.e. ~1.2%)
            if rand::random::<u8>() <= 2 {
                telemetry::event!(
                    "Edit Prediction Request",
                    context_latency = done_gathering_context_at
                        .duration_since(buffer_snapshotted_at)
                        .as_millis(),
                    request_latency = received_response_at
                        .duration_since(done_gathering_context_at)
                        .as_millis(),
                    process_latency = finished_at.duration_since(received_response_at).as_millis()
                );
            }

            edit_prediction
        })
    }
 555
    /// Generates several example completions of various states to fill the Zeta completion
    /// modal.
    ///
    /// How it works:
    /// - Creates a small in-memory buffer and requests seven canned predictions at the start
    ///   of its second line via `fake_completion`.
    /// - Awaits the predictions in order.
    /// - Clears the edits of the shown completions at indices 2 and 3. NOTE(review):
    ///   presumably so the modal also has predictions without edits to display.
    ///
    /// The returned task panics if a fake request fails or if fewer than four completions
    /// were recorded in `shown_completions`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // All fake predictions are requested at the start of the second line.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        // Each canned `output_excerpt` is a slightly different rewrite of the editable region.
        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(async move |zeta, cx| {
            // Await in request order so `shown_completions` ends up in a predictable order.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Turn the third and fourth shown completions into predictions with no edits.
            zeta.update(cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 708
 709    #[cfg(any(test, feature = "test-support"))]
 710    pub fn fake_completion(
 711        &mut self,
 712        project: Option<&Entity<Project>>,
 713        buffer: &Entity<Buffer>,
 714        position: language::Anchor,
 715        response: PredictEditsResponse,
 716        cx: &mut Context<Self>,
 717    ) -> Task<Result<Option<EditPrediction>>> {
 718        use std::future::ready;
 719
 720        self.request_completion_impl(
 721            None,
 722            project,
 723            buffer,
 724            position,
 725            CanCollectData(false),
 726            cx,
 727            |_params| ready(Ok((response, None))),
 728        )
 729    }
 730
 731    pub fn request_completion(
 732        &mut self,
 733        project: Option<&Entity<Project>>,
 734        buffer: &Entity<Buffer>,
 735        position: language::Anchor,
 736        can_collect_data: CanCollectData,
 737        cx: &mut Context<Self>,
 738    ) -> Task<Result<Option<EditPrediction>>> {
 739        self.request_completion_impl(
 740            self.workspace.upgrade(),
 741            project,
 742            buffer,
 743            position,
 744            can_collect_data,
 745            cx,
 746            Self::perform_predict_edits,
 747        )
 748    }
 749
 750    pub fn perform_predict_edits(
 751        params: PerformPredictEditsParams,
 752    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
 753        async move {
 754            let PerformPredictEditsParams {
 755                client,
 756                llm_token,
 757                app_version,
 758                body,
 759                ..
 760            } = params;
 761
 762            let http_client = client.http_client();
 763            let mut token = llm_token.acquire(&client).await?;
 764            let mut did_retry = false;
 765
 766            loop {
 767                let request_builder = http_client::Request::builder().method(Method::POST);
 768                let request_builder =
 769                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 770                        request_builder.uri(predict_edits_url)
 771                    } else {
 772                        request_builder.uri(
 773                            http_client
 774                                .build_zed_llm_url("/predict_edits/v2", &[])?
 775                                .as_ref(),
 776                        )
 777                    };
 778                let request = request_builder
 779                    .header("Content-Type", "application/json")
 780                    .header("Authorization", format!("Bearer {}", token))
 781                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 782                    .body(serde_json::to_string(&body)?.into())?;
 783
 784                let mut response = http_client.send(request).await?;
 785
 786                if let Some(minimum_required_version) = response
 787                    .headers()
 788                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 789                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 790                {
 791                    anyhow::ensure!(
 792                        app_version >= minimum_required_version,
 793                        ZedUpdateRequiredError {
 794                            minimum_version: minimum_required_version
 795                        }
 796                    );
 797                }
 798
 799                if response.status().is_success() {
 800                    let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 801
 802                    let mut body = String::new();
 803                    response.body_mut().read_to_string(&mut body).await?;
 804                    return Ok((serde_json::from_str(&body)?, usage));
 805                } else if !did_retry
 806                    && response
 807                        .headers()
 808                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 809                        .is_some()
 810                {
 811                    did_retry = true;
 812                    token = llm_token.refresh(&client).await?;
 813                } else {
 814                    let mut body = String::new();
 815                    response.body_mut().read_to_string(&mut body).await?;
 816                    anyhow::bail!(
 817                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 818                        response.status(),
 819                        body
 820                    );
 821                }
 822            }
 823        }
 824    }
 825
 826    fn accept_edit_prediction(
 827        &mut self,
 828        request_id: EditPredictionId,
 829        cx: &mut Context<Self>,
 830    ) -> Task<Result<()>> {
 831        let client = self.client.clone();
 832        let llm_token = self.llm_token.clone();
 833        let app_version = AppVersion::global(cx);
 834        cx.spawn(async move |this, cx| {
 835            let http_client = client.http_client();
 836            let mut response = llm_token_retry(&llm_token, &client, |token| {
 837                let request_builder = http_client::Request::builder().method(Method::POST);
 838                let request_builder =
 839                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 840                        request_builder.uri(accept_prediction_url)
 841                    } else {
 842                        request_builder.uri(
 843                            http_client
 844                                .build_zed_llm_url("/predict_edits/accept", &[])?
 845                                .as_ref(),
 846                        )
 847                    };
 848                Ok(request_builder
 849                    .header("Content-Type", "application/json")
 850                    .header("Authorization", format!("Bearer {}", token))
 851                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 852                    .body(
 853                        serde_json::to_string(&AcceptEditPredictionBody {
 854                            request_id: request_id.0,
 855                        })?
 856                        .into(),
 857                    )?)
 858            })
 859            .await?;
 860
 861            if let Some(minimum_required_version) = response
 862                .headers()
 863                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 864                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 865                && app_version < minimum_required_version
 866            {
 867                return Err(anyhow!(ZedUpdateRequiredError {
 868                    minimum_version: minimum_required_version
 869                }));
 870            }
 871
 872            if response.status().is_success() {
 873                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
 874                    this.update(cx, |this, cx| {
 875                        this.user_store.update(cx, |user_store, cx| {
 876                            user_store.update_edit_prediction_usage(usage, cx);
 877                        });
 878                    })?;
 879                }
 880
 881                Ok(())
 882            } else {
 883                let mut body = String::new();
 884                response.body_mut().read_to_string(&mut body).await?;
 885                Err(anyhow!(
 886                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
 887                    response.status(),
 888                    body
 889                ))
 890            }
 891        })
 892    }
 893
    /// Turns a raw prediction response into an [`EditPrediction`] for `buffer`.
    ///
    /// The returned task does three things in order:
    /// - It parses the model's `output_excerpt` into edits against `snapshot` on a
    ///   background thread.
    /// - It re-reads `buffer` and interpolates those edits onto the buffer's current
    ///   snapshot, because the buffer may have changed while the request was in
    ///   flight.
    /// - It computes an edit preview for the interpolated edits.
    ///
    /// The task resolves to `Ok(None)` when interpolation fails (`interpolate`
    /// returned `None`). It resolves to an error when parsing fails or the buffer can
    /// no longer be read.
    ///
    /// The `input_*` strings, `path`, `editable_range`, `cursor_offset` and
    /// `buffer_snapshotted_at` are carried through onto the resulting prediction.
    /// `response_received_at` is stamped once the edit preview is ready.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        buffer_snapshotted_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing and diffing are done off the main thread against the snapshot the
            // request was built from.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-target the edits at the buffer's latest snapshot; if they no longer
            // apply, there is no prediction to offer.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(EditPrediction {
                id: EditPredictionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                buffer_snapshotted_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 955
    /// Extracts the model's rewrite of the editable region from `output_excerpt` and
    /// diffs it against the current text of `editable_range` in `snapshot`.
    ///
    /// After `CURSOR_MARKER` is stripped, the excerpt must contain:
    /// - exactly one `EDITABLE_REGION_START_MARKER`,
    /// - exactly one `EDITABLE_REGION_END_MARKER`,
    /// - at most one `START_OF_FILE_MARKER`.
    ///
    /// The new text is everything after the line that contains the start marker, up
    /// to (not including) the newline that precedes the end marker.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - a marker count is wrong;
    /// - no newline follows the start marker;
    /// - after that point, the end marker is not preceded by a newline.
    fn parse_edits(
        output_excerpt: Arc<str>,
        editable_range: Range<usize>,
        snapshot: &BufferSnapshot,
    ) -> Result<Vec<(Range<Anchor>, String)>> {
        // Cursor markers are not part of the new text.
        let content = output_excerpt.replace(CURSOR_MARKER, "");

        let start_markers = content
            .match_indices(EDITABLE_REGION_START_MARKER)
            .collect::<Vec<_>>();
        anyhow::ensure!(
            start_markers.len() == 1,
            "expected exactly one start marker, found {}",
            start_markers.len()
        );

        let end_markers = content
            .match_indices(EDITABLE_REGION_END_MARKER)
            .collect::<Vec<_>>();
        anyhow::ensure!(
            end_markers.len() == 1,
            "expected exactly one end marker, found {}",
            end_markers.len()
        );

        let sof_markers = content
            .match_indices(START_OF_FILE_MARKER)
            .collect::<Vec<_>>();
        anyhow::ensure!(
            sof_markers.len() <= 1,
            "expected at most one start-of-file marker, found {}",
            sof_markers.len()
        );

        // Skip past the start marker and the remainder of its line.
        let codefence_start = start_markers[0].0;
        let content = &content[codefence_start..];

        let newline_ix = content.find('\n').context("could not find newline")?;
        let content = &content[newline_ix + 1..];

        // The new text ends at the newline immediately preceding the end marker.
        let codefence_end = content
            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
            .context("could not find end marker")?;
        let new_text = &content[..codefence_end];

        let old_text = snapshot
            .text_for_range(editable_range.clone())
            .collect::<String>();

        Ok(Self::compute_edits(
            old_text,
            new_text,
            editable_range.start,
            snapshot,
        ))
    }
1012
    /// Converts a diff between `old_text` and `new_text` into anchored buffer edits.
    ///
    /// `old_text` is expected to be the text of `snapshot` starting at byte `offset`.
    /// `parse_edits` passes the editable range's text and its start.
    ///
    /// Each hunk from `text_diff` is processed as follows:
    /// - It is shifted by `offset`.
    /// - It is trimmed of leading and trailing characters that are unchanged between
    ///   the buffer and the replacement, so the resulting range covers only characters
    ///   that actually differ.
    ///
    /// Pure insertions become an empty range at a single `anchor_after`. Replacements
    /// use `anchor_after(start)..anchor_before(end)`. NOTE(review): presumably this
    /// bias keeps the range from absorbing text typed at its edges later; confirm.
    pub fn compute_edits(
        old_text: String,
        new_text: &str,
        offset: usize,
        snapshot: &BufferSnapshot,
    ) -> Vec<(Range<Anchor>, String)> {
        text_diff(&old_text, new_text)
            .into_iter()
            .map(|(mut old_range, new_text)| {
                // `text_diff` ranges are relative to `old_text`; make them buffer offsets.
                old_range.start += offset;
                old_range.end += offset;

                // Drop characters the hunk leaves unchanged at its start...
                let prefix_len = common_prefix(
                    snapshot.chars_for_range(old_range.clone()),
                    new_text.chars(),
                );
                old_range.start += prefix_len;

                // ...and at its end, comparing only what remains after the prefix.
                let suffix_len = common_prefix(
                    snapshot.reversed_chars_for_range(old_range.clone()),
                    new_text[prefix_len..].chars().rev(),
                );
                old_range.end = old_range.end.saturating_sub(suffix_len);

                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
                let range = if old_range.is_empty() {
                    let anchor = snapshot.anchor_after(old_range.start);
                    anchor..anchor
                } else {
                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
                };
                (range, new_text)
            })
            .collect()
    }
1048
1049    pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool {
1050        self.rated_completions.contains(&completion_id)
1051    }
1052
1053    pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
1054        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
1055            let completion = self.shown_completions.pop_back().unwrap();
1056            self.rated_completions.remove(&completion.id);
1057        }
1058        self.shown_completions.push_front(completion.clone());
1059        cx.notify();
1060    }
1061
    /// Records the user's `rating`, with free-form `feedback`, for `completion`.
    ///
    /// In order, this:
    /// - marks the completion as rated;
    /// - emits an "Edit Prediction Rated" telemetry event carrying the completion's
    ///   input events, excerpt, outline and output excerpt;
    /// - starts flushing pending telemetry without waiting for it;
    /// - notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // The flush task is detached so rating never waits on the network.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1082
1083    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1084        self.shown_completions.iter()
1085    }
1086
1087    pub fn shown_completions_len(&self) -> usize {
1088        self.shown_completions.len()
1089    }
1090
1091    fn report_changes_for_buffer(
1092        &mut self,
1093        buffer: &Entity<Buffer>,
1094        cx: &mut Context<Self>,
1095    ) -> BufferSnapshot {
1096        self.register_buffer(buffer, cx);
1097
1098        let registered_buffer = self
1099            .registered_buffers
1100            .get_mut(&buffer.entity_id())
1101            .unwrap();
1102        let new_snapshot = buffer.read(cx).snapshot();
1103
1104        if new_snapshot.version != registered_buffer.snapshot.version {
1105            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1106            self.push_event(Event::BufferChange {
1107                old_snapshot,
1108                new_snapshot: new_snapshot.clone(),
1109                timestamp: Instant::now(),
1110            });
1111        }
1112
1113        new_snapshot
1114    }
1115
1116    fn load_data_collection_choices() -> DataCollectionChoice {
1117        let choice = KEY_VALUE_STORE
1118            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1119            .log_err()
1120            .flatten();
1121
1122        match choice.as_deref() {
1123            Some("true") => DataCollectionChoice::Enabled,
1124            Some("false") => DataCollectionChoice::Disabled,
1125            Some(_) => {
1126                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1127                DataCollectionChoice::NotAnswered
1128            }
1129            None => DataCollectionChoice::NotAnswered,
1130        }
1131    }
1132
1133    fn gather_git_info(
1134        &mut self,
1135        project: Option<&Entity<Project>>,
1136        buffer_snapshotted_at: &Instant,
1137        snapshot: &BufferSnapshot,
1138        cx: &Context<Self>,
1139    ) -> Option<PredictEditsGitInfo> {
1140        let project = project?.read(cx);
1141        let file = snapshot.file()?;
1142        let project_path = ProjectPath::from_file(file.as_ref(), cx);
1143        let entry = project.entry_for_path(&project_path, cx)?;
1144        if !worktree_entry_eligible_for_collection(&entry) {
1145            return None;
1146        }
1147
1148        let git_store = project.git_store().read(cx);
1149        let (repository, _repo_path) =
1150            git_store.repository_and_path_for_project_path(&project_path, cx)?;
1151
1152        let repository = repository.read(cx);
1153        let head_sha = repository
1154            .head_commit
1155            .as_ref()
1156            .map(|head_commit| head_commit.sha.to_string());
1157        let remote_origin_url = repository.remote_origin_url.clone();
1158        let remote_upstream_url = repository.remote_upstream_url.clone();
1159        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
1160            return None;
1161        }
1162
1163        let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
1164
1165        Some(PredictEditsGitInfo {
1166            head_sha,
1167            remote_origin_url,
1168            remote_upstream_url,
1169            recent_files: Some(recent_files),
1170        })
1171    }
1172
1173    fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
1174        let now = Instant::now();
1175        if let Some(existing_ix) = self
1176            .recent_project_entries
1177            .iter()
1178            .rposition(|(id, _)| *id == project_entry_id)
1179        {
1180            self.recent_project_entries.remove(existing_ix);
1181        }
1182        if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
1183            self.recent_project_entries.pop_front();
1184        }
1185        self.recent_project_entries
1186            .push_back((project_entry_id, now));
1187    }
1188
    /// Lists recently active files that belong to `repository`, most recent first.
    ///
    /// Walks `recent_project_entries` from newest (back) to oldest (front). Each entry
    /// is resolved to:
    /// - its path relative to `repository`;
    /// - how many milliseconds before `now` it was last active.
    ///
    /// Returns an empty list if the workspace can't be read (e.g. it was released).
    ///
    /// As a side effect, entries that can never be reported are pruned:
    /// - the worktree or entry no longer exists, or is ineligible for collection;
    /// - the repo path is not valid UTF-8;
    /// - the repo path exceeds `MAX_RECENT_FILE_PATH_LENGTH`;
    /// - the elapsed time doesn't fit in `active_to_now_ms`.
    ///
    /// Entries that merely live outside `repository` are kept.
    fn recent_files(
        &mut self,
        now: &Instant,
        repository: &Repository,
        cx: &Context<Self>,
    ) -> Vec<PredictEditsRecentFile> {
        let Ok(project) = self
            .workspace
            .read_with(cx, |workspace, _cx| workspace.project().clone())
        else {
            return Vec::new();
        };
        let mut results = Vec::new();
        // Walk indices newest-to-oldest so `remove(ix)` only shifts entries that were
        // already visited.
        for ix in (0..self.recent_project_entries.len()).rev() {
            let (entry_id, last_active_at) = &self.recent_project_entries[ix];
            if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
                && let worktree = worktree.read(cx)
                && let Some(entry) = worktree.entry_for_id(*entry_id)
                && worktree_entry_eligible_for_collection(entry)
            {
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: entry.path.clone(),
                };
                let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx)
                else {
                    // entry not removed since queries involving other repositories might occur later
                    continue;
                };
                let Some(repo_path_str) = repo_path.to_str() else {
                    // paths may not be valid UTF-8
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                if repo_path_str.len() > MAX_RECENT_FILE_PATH_LENGTH {
                    self.recent_project_entries.remove(ix);
                    continue;
                }
                // Elapsed milliseconds come back as u128; drop entries that don't fit
                // the request field's integer type.
                let Ok(active_to_now_ms) =
                    now.duration_since(*last_active_at).as_millis().try_into()
                else {
                    self.recent_project_entries.remove(ix);
                    continue;
                };
                results.push(PredictEditsRecentFile {
                    repo_path: repo_path_str.to_string(),
                    active_to_now_ms,
                });
            } else {
                // Worktree or entry is gone, or the entry isn't eligible for collection.
                self.recent_project_entries.remove(ix);
            }
        }
        results
    }
1243}
1244
1245fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
1246    entry.is_file()
1247        && entry.is_created()
1248        && !entry.is_ignored
1249        && !entry.is_private
1250        && !entry.is_external
1251        && !entry.is_fifo
1252}
1253
/// Inputs for issuing a single predict-edits request.
///
/// NOTE(review): the field descriptions below assume these values feed the
/// predict-edits request code. That code sends `body` as JSON with a bearer token
/// and the app version header. Confirm against the consumer.
pub struct PerformPredictEditsParams {
    /// Client whose HTTP client sends the request.
    pub client: Arc<Client>,
    /// LLM API token used as the bearer credential; refreshable via `client`.
    pub llm_token: LlmApiToken,
    /// Version of the running Zed build.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1260
/// Error raised when the server reports, via `MINIMUM_REQUIRED_VERSION_HEADER_NAME`,
/// a minimum Zed version newer than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The oldest version the server will accept.
    minimum_version: SemanticVersion,
}
1268
/// Returns the byte length of the longest common prefix of two character streams.
///
/// Characters are compared pairwise until the first mismatch or until either stream
/// runs out. The result is measured in UTF-8 bytes, so it can be used directly to
/// offset into the corresponding text. Pass reversed iterators to measure a common
/// suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
1275
/// Output of [`gather_context`].
pub struct GatherContextOutput {
    /// Request body for the predict-edits endpoint.
    pub body: PredictEditsBody,
    /// Byte-offset range of the excerpt's editable region within the buffer snapshot.
    pub editable_range: Range<usize>,
}
1280
1281pub fn gather_context(
1282    project: Option<&Entity<Project>>,
1283    full_path_str: String,
1284    snapshot: &BufferSnapshot,
1285    cursor_point: language::Point,
1286    make_events_prompt: impl FnOnce() -> String + Send + 'static,
1287    can_collect_data: CanCollectData,
1288    git_info: Option<PredictEditsGitInfo>,
1289    cx: &App,
1290) -> Task<Result<GatherContextOutput>> {
1291    let local_lsp_store =
1292        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
1293    let diagnostic_groups: Vec<(String, serde_json::Value)> =
1294        if matches!(can_collect_data, CanCollectData(true))
1295            && let Some(local_lsp_store) = local_lsp_store
1296        {
1297            snapshot
1298                .diagnostic_groups(None)
1299                .into_iter()
1300                .filter_map(|(language_server_id, diagnostic_group)| {
1301                    let language_server =
1302                        local_lsp_store.running_language_server_for_id(language_server_id)?;
1303                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
1304                    let language_server_name = language_server.name().to_string();
1305                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
1306                    Some((language_server_name, serialized))
1307                })
1308                .collect::<Vec<_>>()
1309        } else {
1310            Vec::new()
1311        };
1312
1313    cx.background_spawn({
1314        let snapshot = snapshot.clone();
1315        async move {
1316            let diagnostic_groups = if diagnostic_groups.is_empty()
1317                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
1318            {
1319                None
1320            } else {
1321                Some(diagnostic_groups)
1322            };
1323
1324            let input_excerpt = excerpt_for_cursor_position(
1325                cursor_point,
1326                &full_path_str,
1327                &snapshot,
1328                MAX_REWRITE_TOKENS,
1329                MAX_CONTEXT_TOKENS,
1330            );
1331            let input_events = make_events_prompt();
1332            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
1333
1334            let body = PredictEditsBody {
1335                input_events,
1336                input_excerpt: input_excerpt.prompt,
1337                can_collect_data: can_collect_data.0,
1338                diagnostic_groups,
1339                git_info,
1340                outline: None,
1341                speculated_output: None,
1342            };
1343
1344            Ok(GatherContextOutput {
1345                body,
1346                editable_range,
1347            })
1348        }
1349    })
1350}
1351
1352fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1353    let mut result = String::new();
1354    for event in events.iter().rev() {
1355        let event_string = event.to_prompt();
1356        let event_tokens = tokens_for_bytes(event_string.len());
1357        if event_tokens > remaining_tokens {
1358            break;
1359        }
1360
1361        if !result.is_empty() {
1362            result.insert_str(0, "\n\n");
1363        }
1364        result.insert_str(0, &event_string);
1365        remaining_tokens -= event_tokens;
1366    }
1367    result
1368}
1369
/// Bookkeeping for a buffer that Zeta tracks.
struct RegisteredBuffer {
    /// The most recent snapshot reported for this buffer. New snapshots are compared
    /// against it to detect changes.
    snapshot: BufferSnapshot,
    /// Held only so these subscriptions stay alive while the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
}
1374
/// An entry in the edit history that is rendered into prediction prompts.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1383
1384impl Event {
1385    fn to_prompt(&self) -> String {
1386        match self {
1387            Event::BufferChange {
1388                old_snapshot,
1389                new_snapshot,
1390                ..
1391            } => {
1392                let mut prompt = String::new();
1393
1394                let old_path = old_snapshot
1395                    .file()
1396                    .map(|f| f.path().as_ref())
1397                    .unwrap_or(Path::new("untitled"));
1398                let new_path = new_snapshot
1399                    .file()
1400                    .map(|f| f.path().as_ref())
1401                    .unwrap_or(Path::new("untitled"));
1402                if old_path != new_path {
1403                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1404                }
1405
1406                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1407                if !diff.is_empty() {
1408                    write!(
1409                        prompt,
1410                        "User edited {:?}:\n```diff\n{}\n```",
1411                        new_path, diff
1412                    )
1413                    .unwrap();
1414                }
1415
1416                prompt
1417            }
1418        }
1419    }
1420}
1421
/// The prediction currently offered by the provider, with the buffer it targets.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction applies to.
    buffer_id: EntityId,
    completion: EditPrediction,
}
1427
1428impl CurrentEditPrediction {
1429    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1430        if self.buffer_id != old_completion.buffer_id {
1431            return true;
1432        }
1433
1434        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
1435            return true;
1436        };
1437        let Some(new_edits) = self.completion.interpolate(snapshot) else {
1438            return false;
1439        };
1440
1441        if old_edits.len() == 1 && new_edits.len() == 1 {
1442            let (old_range, old_text) = &old_edits[0];
1443            let (new_range, new_text) = &new_edits[0];
1444            new_range == old_range && new_text.starts_with(old_text)
1445        } else {
1446            true
1447        }
1448    }
1449}
1450
/// A prediction refresh that is still in flight.
struct PendingCompletion {
    /// Sequence number taken from the provider's `next_pending_completion_id`.
    id: usize,
    /// Held to keep the refresh task alive. NOTE(review): presumably dropping it
    /// cancels the request, per gpui task semantics; confirm.
    _task: Task<()>,
}
1455
/// The user's answer to the edit-prediction data collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not decided yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has opted either in or out.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// Returns the opposite choice. An unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value { Self::Enabled } else { Self::Disabled }
    }
}
1495
/// Data collection state for the buffer served by an edit prediction provider.
pub struct ProviderDataCollection {
    /// `None` when data collection is not possible for the provider's buffer.
    choice: Option<Entity<DataCollectionChoice>>,
    /// Watcher consulted to decide whether the buffer's project is open source. `new`
    /// sets it together with `choice`, so both are `None` or both are `Some`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1501
/// Whether data from the current request may be collected.
///
/// Produced by `ProviderDataCollection::can_collect_data`, which requires both that
/// the user opted in and that the project is open source.
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
1504
1505impl ProviderDataCollection {
1506    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1507        let choice_and_watcher = buffer.and_then(|buffer| {
1508            let file = buffer.read(cx).file()?;
1509
1510            if !file.is_local() || file.is_private() {
1511                return None;
1512            }
1513
1514            let zeta = zeta.read(cx);
1515            let choice = zeta.data_collection_choice.clone();
1516
1517            let license_detection_watcher = zeta
1518                .license_detection_watchers
1519                .get(&file.worktree_id(cx))
1520                .cloned()?;
1521
1522            Some((choice, license_detection_watcher))
1523        });
1524
1525        if let Some((choice, watcher)) = choice_and_watcher {
1526            ProviderDataCollection {
1527                choice: Some(choice),
1528                license_detection_watcher: Some(watcher),
1529            }
1530        } else {
1531            ProviderDataCollection {
1532                choice: None,
1533                license_detection_watcher: None,
1534            }
1535        }
1536    }
1537
1538    pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
1539        CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
1540    }
1541
1542    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1543        self.choice
1544            .as_ref()
1545            .is_some_and(|choice| choice.read(cx).is_enabled())
1546    }
1547
1548    fn is_project_open_source(&self) -> bool {
1549        self.license_detection_watcher
1550            .as_ref()
1551            .is_some_and(|watcher| watcher.is_project_open_source())
1552    }
1553
1554    pub fn toggle(&mut self, cx: &mut App) {
1555        if let Some(choice) = self.choice.as_mut() {
1556            let new_choice = choice.update(cx, |choice, _cx| {
1557                let new_choice = choice.toggle();
1558                *choice = new_choice;
1559                new_choice
1560            });
1561
1562            db::write_and_log(cx, move || {
1563                KEY_VALUE_STORE.write_kvp(
1564                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1565                    new_choice.is_enabled().to_string(),
1566                )
1567            });
1568        }
1569    }
1570}
1571
1572async fn llm_token_retry(
1573    llm_token: &LlmApiToken,
1574    client: &Arc<Client>,
1575    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
1576) -> Result<Response<AsyncBody>> {
1577    let mut did_retry = false;
1578    let http_client = client.http_client();
1579    let mut token = llm_token.acquire(client).await?;
1580    loop {
1581        let request = build_request(token.clone())?;
1582        let response = http_client.send(request).await?;
1583
1584        if !did_retry
1585            && !response.status().is_success()
1586            && response
1587                .headers()
1588                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1589                .is_some()
1590        {
1591            did_retry = true;
1592            token = llm_token.refresh(client).await?;
1593            continue;
1594        }
1595
1596        return Ok(response);
1597    }
1598}
1599
/// Edit prediction provider backed by a [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Refresh requests still in flight; at most two are tracked at once.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next `PendingCompletion`.
    next_pending_completion_id: usize,
    /// The prediction currently offered for display, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// Data collection state for the provider's buffer. This field is never absent.
    /// When collection isn't possible for the buffer, its inner `choice` is `None`.
    provider_data_collection: ProviderDataCollection,
    /// When the last prediction request was started; used to throttle refreshes.
    last_request_timestamp: Instant,
}
1609
1610impl ZetaEditPredictionProvider {
1611    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1612
1613    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1614        Self {
1615            zeta,
1616            pending_completions: ArrayVec::new(),
1617            next_pending_completion_id: 0,
1618            current_completion: None,
1619            provider_data_collection,
1620            last_request_timestamp: Instant::now(),
1621        }
1622    }
1623}
1624
1625impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
1626    fn name() -> &'static str {
1627        "zed-predict"
1628    }
1629
1630    fn display_name() -> &'static str {
1631        "Zed's Edit Predictions"
1632    }
1633
1634    fn show_completions_in_menu() -> bool {
1635        true
1636    }
1637
1638    fn show_tab_accept_marker() -> bool {
1639        true
1640    }
1641
1642    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1643        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1644
1645        if self.provider_data_collection.is_data_collection_enabled(cx) {
1646            DataCollectionState::Enabled {
1647                is_project_open_source,
1648            }
1649        } else {
1650            DataCollectionState::Disabled {
1651                is_project_open_source,
1652            }
1653        }
1654    }
1655
1656    fn toggle_data_collection(&mut self, cx: &mut App) {
1657        self.provider_data_collection.toggle(cx);
1658    }
1659
1660    fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1661        self.zeta.read(cx).usage(cx)
1662    }
1663
1664    fn is_enabled(
1665        &self,
1666        _buffer: &Entity<Buffer>,
1667        _cursor_position: language::Anchor,
1668        _cx: &App,
1669    ) -> bool {
1670        true
1671    }
1672    fn is_refreshing(&self) -> bool {
1673        !self.pending_completions.is_empty()
1674    }
1675
1676    fn refresh(
1677        &mut self,
1678        project: Option<Entity<Project>>,
1679        buffer: Entity<Buffer>,
1680        position: language::Anchor,
1681        _debounce: bool,
1682        cx: &mut Context<Self>,
1683    ) {
1684        if self.zeta.read(cx).update_required {
1685            return;
1686        }
1687
1688        if self
1689            .zeta
1690            .read(cx)
1691            .user_store
1692            .read_with(cx, |user_store, _cx| {
1693                user_store.account_too_young() || user_store.has_overdue_invoices()
1694            })
1695        {
1696            return;
1697        }
1698
1699        if let Some(current_completion) = self.current_completion.as_ref() {
1700            let snapshot = buffer.read(cx).snapshot();
1701            if current_completion
1702                .completion
1703                .interpolate(&snapshot)
1704                .is_some()
1705            {
1706                return;
1707            }
1708        }
1709
1710        let pending_completion_id = self.next_pending_completion_id;
1711        self.next_pending_completion_id += 1;
1712        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1713        let last_request_timestamp = self.last_request_timestamp;
1714
1715        let task = cx.spawn(async move |this, cx| {
1716            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1717                .checked_duration_since(Instant::now())
1718            {
1719                cx.background_executor().timer(timeout).await;
1720            }
1721
1722            let completion_request = this.update(cx, |this, cx| {
1723                this.last_request_timestamp = Instant::now();
1724                this.zeta.update(cx, |zeta, cx| {
1725                    zeta.request_completion(
1726                        project.as_ref(),
1727                        &buffer,
1728                        position,
1729                        can_collect_data,
1730                        cx,
1731                    )
1732                })
1733            });
1734
1735            let completion = match completion_request {
1736                Ok(completion_request) => {
1737                    let completion_request = completion_request.await;
1738                    completion_request.map(|c| {
1739                        c.map(|completion| CurrentEditPrediction {
1740                            buffer_id: buffer.entity_id(),
1741                            completion,
1742                        })
1743                    })
1744                }
1745                Err(error) => Err(error),
1746            };
1747            let Some(new_completion) = completion
1748                .context("edit prediction failed")
1749                .log_err()
1750                .flatten()
1751            else {
1752                this.update(cx, |this, cx| {
1753                    if this.pending_completions[0].id == pending_completion_id {
1754                        this.pending_completions.remove(0);
1755                    } else {
1756                        this.pending_completions.clear();
1757                    }
1758
1759                    cx.notify();
1760                })
1761                .ok();
1762                return;
1763            };
1764
1765            this.update(cx, |this, cx| {
1766                if this.pending_completions[0].id == pending_completion_id {
1767                    this.pending_completions.remove(0);
1768                } else {
1769                    this.pending_completions.clear();
1770                }
1771
1772                if let Some(old_completion) = this.current_completion.as_ref() {
1773                    let snapshot = buffer.read(cx).snapshot();
1774                    if new_completion.should_replace_completion(old_completion, &snapshot) {
1775                        this.zeta.update(cx, |zeta, cx| {
1776                            zeta.completion_shown(&new_completion.completion, cx);
1777                        });
1778                        this.current_completion = Some(new_completion);
1779                    }
1780                } else {
1781                    this.zeta.update(cx, |zeta, cx| {
1782                        zeta.completion_shown(&new_completion.completion, cx);
1783                    });
1784                    this.current_completion = Some(new_completion);
1785                }
1786
1787                cx.notify();
1788            })
1789            .ok();
1790        });
1791
1792        // We always maintain at most two pending completions. When we already
1793        // have two, we replace the newest one.
1794        if self.pending_completions.len() <= 1 {
1795            self.pending_completions.push(PendingCompletion {
1796                id: pending_completion_id,
1797                _task: task,
1798            });
1799        } else if self.pending_completions.len() == 2 {
1800            self.pending_completions.pop();
1801            self.pending_completions.push(PendingCompletion {
1802                id: pending_completion_id,
1803                _task: task,
1804            });
1805        }
1806    }
1807
1808    fn cycle(
1809        &mut self,
1810        _buffer: Entity<Buffer>,
1811        _cursor_position: language::Anchor,
1812        _direction: edit_prediction::Direction,
1813        _cx: &mut Context<Self>,
1814    ) {
1815        // Right now we don't support cycling.
1816    }
1817
1818    fn accept(&mut self, cx: &mut Context<Self>) {
1819        let completion_id = self
1820            .current_completion
1821            .as_ref()
1822            .map(|completion| completion.completion.id);
1823        if let Some(completion_id) = completion_id {
1824            self.zeta
1825                .update(cx, |zeta, cx| {
1826                    zeta.accept_edit_prediction(completion_id, cx)
1827                })
1828                .detach();
1829        }
1830        self.pending_completions.clear();
1831    }
1832
1833    fn discard(&mut self, _cx: &mut Context<Self>) {
1834        self.pending_completions.clear();
1835        self.current_completion.take();
1836    }
1837
1838    fn suggest(
1839        &mut self,
1840        buffer: &Entity<Buffer>,
1841        cursor_position: language::Anchor,
1842        cx: &mut Context<Self>,
1843    ) -> Option<edit_prediction::EditPrediction> {
1844        let CurrentEditPrediction {
1845            buffer_id,
1846            completion,
1847            ..
1848        } = self.current_completion.as_mut()?;
1849
1850        // Invalidate previous completion if it was generated for a different buffer.
1851        if *buffer_id != buffer.entity_id() {
1852            self.current_completion.take();
1853            return None;
1854        }
1855
1856        let buffer = buffer.read(cx);
1857        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1858            self.current_completion.take();
1859            return None;
1860        };
1861
1862        let cursor_row = cursor_position.to_point(buffer).row;
1863        let (closest_edit_ix, (closest_edit_range, _)) =
1864            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1865                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1866                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1867                cmp::min(distance_from_start, distance_from_end)
1868            })?;
1869
1870        let mut edit_start_ix = closest_edit_ix;
1871        for (range, _) in edits[..edit_start_ix].iter().rev() {
1872            let distance_from_closest_edit =
1873                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1874            if distance_from_closest_edit <= 1 {
1875                edit_start_ix -= 1;
1876            } else {
1877                break;
1878            }
1879        }
1880
1881        let mut edit_end_ix = closest_edit_ix + 1;
1882        for (range, _) in &edits[edit_end_ix..] {
1883            let distance_from_closest_edit =
1884                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1885            if distance_from_closest_edit <= 1 {
1886                edit_end_ix += 1;
1887            } else {
1888                break;
1889            }
1890        }
1891
1892        Some(edit_prediction::EditPrediction {
1893            id: Some(completion.id.to_string().into()),
1894            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1895            edit_preview: Some(completion.edit_preview.clone()),
1896        })
1897    }
1898}
1899
/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
///
/// The bytes-per-token ratio is deliberately small, so the estimate leans toward more tokens
/// per byte. Input limits derived from it therefore err on the conservative side.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed string bytes per model token; kept low to underestimate how much input fits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1906
#[cfg(test)]
mod tests {
    use client::UserStore;
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use settings::SettingsStore;

    use super::*;

    /// Checks that a prediction's edits are re-mapped correctly as the user types (partially
    /// or fully) what was predicted, undoes, or deletes text, and that the prediction stops
    /// interpolating once the buffer diverges from it.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = EditPrediction {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: EditPredictionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Checks that a model response is diffed into minimal insertions rather than whole-line
    /// replacements.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// Checks that a prediction which appends a line at the very end of the buffer applies
    /// cleanly.
    #[gpui::test]
    async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        let (zeta, _server) = zeta_with_fake_llm(
            completion_response.to_string(),
            Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45").unwrap(),
            cx,
        )
        .await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Requests a prediction for `buffer_content` (cursor at row 1, column 0) from a fake
    /// server that answers with `completion_response`, and returns the resulting edits as
    /// point ranges in the original buffer.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let (zeta, _server) =
            zeta_with_fake_llm(completion_response.to_string(), Uuid::new_v4(), cx).await;

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
        });

        let completion = completion_task.await.unwrap().unwrap();
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
            .collect::<Vec<_>>()
    }

    /// Builds a [`Zeta`] whose HTTP client issues an LLM token on `POST /client/llm_tokens`
    /// and answers `POST /predict_edits/v2` with `output_excerpt` under `request_id`. Every
    /// other route returns 404.
    ///
    /// The [`FakeServer`] that authenticates the client is returned so the caller can keep
    /// it alive for the whole test.
    async fn zeta_with_fake_llm(
        output_excerpt: String,
        request_id: Uuid,
        cx: &mut TestAppContext,
    ) -> (Entity<Zeta>, FakeServer) {
        let http_client = FakeHttpClient::create(move |req| {
            let output_excerpt = output_excerpt.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&PredictEditsResponse {
                                request_id,
                                output_excerpt,
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        // Construct the fake server to authenticate.
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
        (zeta, server)
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`'s current text.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back into offsets in `buffer`'s current text.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}