zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_banner;
   6mod onboarding_modal;
   7mod onboarding_telemetry;
   8mod rate_completion_modal;
   9
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::KEY_VALUE_STORE;
  12use editor::Editor;
  13pub use init::*;
  14use inline_completion::DataCollectionState;
  15pub use license_detection::is_license_eligible_for_data_collection;
  16use license_detection::LICENSE_FILES_TO_CHECK;
  17pub use onboarding_banner::*;
  18pub use rate_completion_modal::*;
  19
  20use anyhow::{anyhow, Context as _, Result};
  21use arrayvec::ArrayVec;
  22use client::{Client, UserStore};
  23use collections::{HashMap, HashSet, VecDeque};
  24use futures::AsyncReadExt;
  25use gpui::{
  26    actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  27    Subscription, Task,
  28};
  29use http_client::{HttpClient, Method};
  30use input_excerpt::excerpt_for_cursor_position;
  31use language::{
  32    text_diff, Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint,
  33};
  34use language_model::{LlmApiToken, RefreshLlmTokenListener};
  35use postage::watch;
  36use project::Project;
  37use release_channel::AppVersion;
  38use settings::WorktreeId;
  39use std::str::FromStr;
  40use std::{
  41    borrow::Cow,
  42    cmp,
  43    fmt::Write,
  44    future::Future,
  45    mem,
  46    ops::Range,
  47    path::Path,
  48    rc::Rc,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::InlineCompletionRating;
  53use thiserror::Error;
  54use util::ResultExt;
  55use uuid::Uuid;
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  57use workspace::Workspace;
  58use worktree::Worktree;
  59use zed_llm_client::{
  60    PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
  61    MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  62};
  63
  64const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  65const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  66const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  67const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  68const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  69const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  70
  71const MAX_CONTEXT_TOKENS: usize = 150;
  72const MAX_REWRITE_TOKENS: usize = 350;
  73const MAX_EVENT_TOKENS: usize = 500;
  74
  75/// Maximum number of events to track.
  76const MAX_EVENT_COUNT: usize = 16;
  77
// Declares the `ClearHistory` action in the `edit_prediction` namespace. Its handler is
// registered outside this chunk (presumably calling `Zeta::clear_history` — TODO confirm).
actions!(edit_prediction, [ClearHistory]);
  79
/// Identifier of an [`InlineCompletion`]; wraps the `request_id` from the prediction
/// server's response.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineCompletionId(Uuid);
  82
  83impl From<InlineCompletionId> for gpui::ElementId {
  84    fn from(value: InlineCompletionId) -> Self {
  85        gpui::ElementId::Uuid(value.0)
  86    }
  87}
  88
  89impl std::fmt::Display for InlineCompletionId {
  90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  91        write!(f, "{}", self.0)
  92    }
  93}
  94
/// App-global slot holding the shared [`Zeta`] entity. It is installed by
/// [`Zeta::register`] and read back by [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  99
/// A prediction from the server, re-based onto the buffer contents at processing time.
/// It also keeps the prompt inputs and the timings that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Request id reported in the server's response.
    id: InlineCompletionId,
    /// Full path of the buffer's file, or `untitled` for buffers without a file.
    path: Arc<Path>,
    /// Offset range of the editable region in the snapshot the request was built from.
    /// NOTE(review): not re-based onto `snapshot` — confirm consumers expect request-time
    /// offsets.
    excerpt_range: Range<usize>,
    /// Cursor offset at request time.
    cursor_offset: usize,
    /// Predicted edits, interpolated onto `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot taken while processing the response; `edits` were interpolated onto it.
    snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    edit_preview: EditPreview,
    /// Outline text that was sent in the prompt.
    input_outline: Arc<str>,
    /// Recent-events text that was sent in the prompt.
    input_events: Arc<str>,
    /// Cursor excerpt that was sent in the prompt.
    input_excerpt: Arc<str>,
    /// Raw excerpt text returned by the model.
    output_excerpt: Arc<str>,
    /// When the request task started. This is before prompt construction, so `latency()`
    /// includes building the prompt.
    request_sent_at: Instant,
    /// When parsing, interpolation and the edit preview had finished.
    response_received_at: Instant,
}
 116
 117impl InlineCompletion {
 118    fn latency(&self) -> Duration {
 119        self.response_received_at
 120            .duration_since(self.request_sent_at)
 121    }
 122
 123    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 124        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 125    }
 126}
 127
/// Re-bases model edits computed against `old_snapshot` onto `new_snapshot`, accounting
/// for the user edits made between the two versions.
///
/// User edits (`new_snapshot.edits_since(&old_snapshot.version)`) are walked in order,
/// alongside the model edits:
/// - Model edits whose old range ends strictly before the user edit starts are kept as-is.
/// - Suppose the user edit replaces exactly the next model edit's old range, and the text
///   the user inserted is a prefix of the model's replacement. Then that model edit is
///   consumed, and any remaining suffix is kept as an insertion just after the user edit.
/// - Any other user edit returns `None`. This includes a user edit that comes after all
///   model edits.
///
/// Model edits left after the last user edit are kept unchanged. Returns `None` when no
/// edits remain.
///
/// NOTE(review): the single forward pass assumes `current_edits` is sorted by position
/// and non-overlapping — confirm against `parse_edits`.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.into_iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that ends strictly before this user edit starts.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit must replace exactly the next model edit's range with a prefix of
        // that edit's new text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Only the part the user has not produced yet is still suggested.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit conflicts with, or falls outside of, the predicted edits.
        return None;
    }

    edits.extend(model_edits.cloned());

    if edits.is_empty() {
        None
    } else {
        Some(edits)
    }
}
 177
 178impl std::fmt::Debug for InlineCompletion {
 179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 180        f.debug_struct("InlineCompletion")
 181            .field("id", &self.id)
 182            .field("path", &self.path)
 183            .field("edits", &self.edits)
 184            .finish_non_exhaustive()
 185    }
 186}
 187
/// Edit-prediction state, shared app-wide through [`ZetaGlobal`].
///
/// Holds the recent buffer-change history, the buffers being tracked, the completions
/// shown and rated, and the credentials and flags used when requesting predictions.
pub struct Zeta {
    /// Editor passed when the global instance was first created. It is used to find the
    /// workspace for "update required" notifications.
    editor: Option<Entity<Editor>>,
    client: Arc<Client>,
    /// Recent buffer changes, used as prompt context; trimmed at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by entity id. An entry is removed when its
    /// buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions that have been shown (maintained outside this chunk).
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of completions the user has rated (maintained outside this chunk).
    rated_completions: HashSet<InlineCompletionId>,
    data_collection_choice: Entity<DataCollectionChoice>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether the terms of service have been accepted.
    tos_accepted: bool,
    /// Whether an update to a newer version of Zed is required to continue using Zeta.
    update_required: bool,
    _user_store_subscription: Subscription,
    /// One license watcher per worktree, created lazily by `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
 205
 206impl Zeta {
 207    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 208        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 209    }
 210
 211    pub fn register(
 212        editor: Option<Entity<Editor>>,
 213        worktree: Option<Entity<Worktree>>,
 214        client: Arc<Client>,
 215        user_store: Entity<UserStore>,
 216        cx: &mut App,
 217    ) -> Entity<Self> {
 218        let this = Self::global(cx).unwrap_or_else(|| {
 219            let entity = cx.new(|cx| Self::new(editor, client, user_store, cx));
 220            cx.set_global(ZetaGlobal(entity.clone()));
 221            entity
 222        });
 223
 224        this.update(cx, move |this, cx| {
 225            if let Some(worktree) = worktree {
 226                worktree.update(cx, |worktree, cx| {
 227                    this.license_detection_watchers
 228                        .entry(worktree.id())
 229                        .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
 230                });
 231            }
 232        });
 233
 234        this
 235    }
 236
 237    pub fn clear_history(&mut self) {
 238        self.events.clear();
 239    }
 240
 241    fn new(
 242        editor: Option<Entity<Editor>>,
 243        client: Arc<Client>,
 244        user_store: Entity<UserStore>,
 245        cx: &mut Context<Self>,
 246    ) -> Self {
 247        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 248
 249        let data_collection_choice = Self::load_data_collection_choices();
 250        let data_collection_choice = cx.new(|_| data_collection_choice);
 251
 252        Self {
 253            editor,
 254            client,
 255            events: VecDeque::new(),
 256            shown_completions: VecDeque::new(),
 257            rated_completions: HashSet::default(),
 258            registered_buffers: HashMap::default(),
 259            data_collection_choice,
 260            llm_token: LlmApiToken::default(),
 261            _llm_token_subscription: cx.subscribe(
 262                &refresh_llm_token_listener,
 263                |this, _listener, _event, cx| {
 264                    let client = this.client.clone();
 265                    let llm_token = this.llm_token.clone();
 266                    cx.spawn(|_this, _cx| async move {
 267                        llm_token.refresh(&client).await?;
 268                        anyhow::Ok(())
 269                    })
 270                    .detach_and_log_err(cx);
 271                },
 272            ),
 273            tos_accepted: user_store
 274                .read(cx)
 275                .current_user_has_accepted_terms()
 276                .unwrap_or(false),
 277            update_required: false,
 278            _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
 279                match event {
 280                    client::user::Event::PrivateUserInfoUpdated => {
 281                        this.tos_accepted = user_store
 282                            .read(cx)
 283                            .current_user_has_accepted_terms()
 284                            .unwrap_or(false);
 285                    }
 286                    _ => {}
 287                }
 288            }),
 289            license_detection_watchers: HashMap::default(),
 290        }
 291    }
 292
 293    fn push_event(&mut self, event: Event) {
 294        if let Some(Event::BufferChange {
 295            new_snapshot: last_new_snapshot,
 296            timestamp: last_timestamp,
 297            ..
 298        }) = self.events.back_mut()
 299        {
 300            // Coalesce edits for the same buffer when they happen one after the other.
 301            let Event::BufferChange {
 302                old_snapshot,
 303                new_snapshot,
 304                timestamp,
 305            } = &event;
 306
 307            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 308                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 309                && old_snapshot.version == last_new_snapshot.version
 310            {
 311                *last_new_snapshot = new_snapshot.clone();
 312                *last_timestamp = *timestamp;
 313                return;
 314            }
 315        }
 316
 317        self.events.push_back(event);
 318        if self.events.len() >= MAX_EVENT_COUNT {
 319            self.events.drain(..MAX_EVENT_COUNT / 2);
 320        }
 321    }
 322
 323    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 324        let buffer_id = buffer.entity_id();
 325        let weak_buffer = buffer.downgrade();
 326
 327        if let std::collections::hash_map::Entry::Vacant(entry) =
 328            self.registered_buffers.entry(buffer_id)
 329        {
 330            let snapshot = buffer.read(cx).snapshot();
 331
 332            entry.insert(RegisteredBuffer {
 333                snapshot,
 334                _subscriptions: [
 335                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 336                        this.handle_buffer_event(buffer, event, cx);
 337                    }),
 338                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 339                        this.registered_buffers.remove(&weak_buffer.entity_id());
 340                    }),
 341                ],
 342            });
 343        };
 344    }
 345
 346    fn handle_buffer_event(
 347        &mut self,
 348        buffer: Entity<Buffer>,
 349        event: &language::BufferEvent,
 350        cx: &mut Context<Self>,
 351    ) {
 352        if let language::BufferEvent::Edited = event {
 353            self.report_changes_for_buffer(&buffer, cx);
 354        }
 355    }
 356
    /// Shared implementation behind [`Zeta::request_completion`] and the test-only
    /// [`Zeta::fake_completion`].
    ///
    /// Steps:
    /// 1. Records pending changes to `buffer`.
    /// 2. Builds the prompt on the background executor: the excerpt around `cursor`, the
    ///    recent events, the outline and, when the project has a local LSP store, the
    ///    buffer's diagnostics.
    /// 3. Sends the prompt through `perform_predict_edits`.
    /// 4. Converts the response with [`Zeta::process_completion_response`].
    ///
    /// On a [`ZedUpdateRequiredError`], `update_required` is set and, if `workspace` is
    /// given, an "Update Zed" notification is shown. The error is still returned.
    ///
    /// `perform_predict_edits` is injected so tests can supply a canned response.
    #[allow(clippy::too_many_arguments)]
    fn request_completion_impl<F, R>(
        &mut self,
        workspace: Option<Entity<Workspace>>,
        project: Option<&Entity<Project>>,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        can_collect_data: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<InlineCompletion>>>
    where
        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
    {
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let diagnostic_groups = snapshot.diagnostic_groups(None);
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let events = self.events.clone();
        let path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));

        let zeta = cx.entity();
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);

        let buffer = buffer.clone();

        // Diagnostics are only attached when the project has a local LSP store. Each group
        // is labelled with the name of the language server that produced it, and groups
        // whose server is not currently running are dropped.
        let local_lsp_store =
            project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
        let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
            Some(
                diagnostic_groups
                    .into_iter()
                    .filter_map(|(language_server_id, diagnostic_group)| {
                        let language_server =
                            local_lsp_store.running_language_server_for_id(language_server_id)?;

                        Some((
                            language_server.name(),
                            diagnostic_group.resolve::<usize>(&snapshot),
                        ))
                    })
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };

        cx.spawn(|_, cx| async move {
            // Taken before prompt construction, so the completion's latency includes it.
            let request_sent_at = Instant::now();

            struct BackgroundValues {
                input_events: String,
                input_excerpt: String,
                speculated_output: String,
                editable_range: Range<usize>,
                input_outline: String,
            }

            // Prompt construction runs on the background executor.
            let values = cx
                .background_spawn({
                    let snapshot = snapshot.clone();
                    let path = path.clone();
                    async move {
                        let path = path.to_string_lossy();
                        let input_excerpt = excerpt_for_cursor_position(
                            cursor_point,
                            &path,
                            &snapshot,
                            MAX_REWRITE_TOKENS,
                            MAX_CONTEXT_TOKENS,
                        );
                        let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
                        let input_outline = prompt_for_outline(&snapshot);

                        anyhow::Ok(BackgroundValues {
                            input_events,
                            input_excerpt: input_excerpt.prompt,
                            speculated_output: input_excerpt.speculated_output,
                            editable_range: input_excerpt.editable_range.to_offset(&snapshot),
                            input_outline,
                        })
                    }
                })
                .await?;

            log::debug!(
                "Events:\n{}\nExcerpt:\n{:?}",
                values.input_events,
                values.input_excerpt
            );

            let body = PredictEditsBody {
                input_events: values.input_events.clone(),
                input_excerpt: values.input_excerpt.clone(),
                speculated_output: Some(values.speculated_output),
                outline: Some(values.input_outline.clone()),
                can_collect_data,
                // If any group fails to serialize, the error is logged and all diagnostics
                // are omitted; the request itself does not fail.
                diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
                    diagnostic_groups
                        .into_iter()
                        .map(|(name, diagnostic_group)| {
                            Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
                        })
                        .collect::<Result<Vec<_>>>()
                        .log_err()
                }),
            };

            let response = perform_predict_edits(PerformPredictEditsParams {
                client,
                llm_token,
                app_version,
                body,
            })
            .await;
            let response = match response {
                Ok(response) => response,
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        // Record that this client is too old and point the user at the
                        // releases page. Failure to update the app is ignored.
                        cx.update(|cx| {
                            zeta.update(cx, |zeta, _cx| {
                                zeta.update_required = true;
                            });

                            if let Some(workspace) = workspace {
                                workspace.update(cx, |workspace, cx| {
                                    workspace.show_notification(
                                        NotificationId::unique::<ZedUpdateRequiredError>(),
                                        cx,
                                        |cx| {
                                            cx.new(|_| {
                                                ErrorMessagePrompt::new(err.to_string())
                                                    .with_link_button(
                                                        "Update Zed",
                                                        "https://zed.dev/releases",
                                                    )
                                            })
                                        },
                                    );
                                });
                            }
                        })
                        .ok();
                    }

                    return Err(err);
                }
            };

            log::debug!("completion response: {}", &response.output_excerpt);

            Self::process_completion_response(
                response,
                buffer,
                &snapshot,
                values.editable_range,
                cursor_offset,
                path,
                values.input_outline,
                values.input_events,
                values.input_excerpt,
                request_sent_at,
                &cx,
            )
            .await
        })
    }
 530
 531    // Generates several example completions of various states to fill the Zeta completion modal
 532    #[cfg(any(test, feature = "test-support"))]
 533    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
 534        use language::Point;
 535
 536        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 537            And maybe a short line
 538
 539            Then a few lines
 540
 541            and then another
 542            "#};
 543
 544        let project = None;
 545        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
 546        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 547
 548        let completion_tasks = vec![
 549            self.fake_completion(
 550                project,
 551                &buffer,
 552                position,
 553                PredictEditsResponse {
 554                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
 555                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 556a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 557[here's an edit]
 558And maybe a short line
 559Then a few lines
 560and then another
 561{EDITABLE_REGION_END_MARKER}
 562                        ", ),
 563                },
 564                cx,
 565            ),
 566            self.fake_completion(
 567                project,
 568                &buffer,
 569                position,
 570                PredictEditsResponse {
 571                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
 572                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 573a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 574And maybe a short line
 575[and another edit]
 576Then a few lines
 577and then another
 578{EDITABLE_REGION_END_MARKER}
 579                        "#),
 580                },
 581                cx,
 582            ),
 583            self.fake_completion(
 584                project,
 585                &buffer,
 586                position,
 587                PredictEditsResponse {
 588                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
 589                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 590a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 591And maybe a short line
 592
 593Then a few lines
 594
 595and then another
 596{EDITABLE_REGION_END_MARKER}
 597                        "#),
 598                },
 599                cx,
 600            ),
 601            self.fake_completion(
 602                project,
 603                &buffer,
 604                position,
 605                PredictEditsResponse {
 606                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
 607                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 608a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 609And maybe a short line
 610
 611Then a few lines
 612
 613and then another
 614{EDITABLE_REGION_END_MARKER}
 615                        "#),
 616                },
 617                cx,
 618            ),
 619            self.fake_completion(
 620                project,
 621                &buffer,
 622                position,
 623                PredictEditsResponse {
 624                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
 625                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 626a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 627And maybe a short line
 628Then a few lines
 629[a third completion]
 630and then another
 631{EDITABLE_REGION_END_MARKER}
 632                        "#),
 633                },
 634                cx,
 635            ),
 636            self.fake_completion(
 637                project,
 638                &buffer,
 639                position,
 640                PredictEditsResponse {
 641                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
 642                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 643a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 644And maybe a short line
 645and then another
 646[fourth completion example]
 647{EDITABLE_REGION_END_MARKER}
 648                        "#),
 649                },
 650                cx,
 651            ),
 652            self.fake_completion(
 653                project,
 654                &buffer,
 655                position,
 656                PredictEditsResponse {
 657                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
 658                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 659a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 660And maybe a short line
 661Then a few lines
 662and then another
 663[fifth and final completion]
 664{EDITABLE_REGION_END_MARKER}
 665                        "#),
 666                },
 667                cx,
 668            ),
 669        ];
 670
 671        cx.spawn(|zeta, mut cx| async move {
 672            for task in completion_tasks {
 673                task.await.unwrap();
 674            }
 675
 676            zeta.update(&mut cx, |zeta, _cx| {
 677                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
 678                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
 679            })
 680            .ok();
 681        })
 682    }
 683
 684    #[cfg(any(test, feature = "test-support"))]
 685    pub fn fake_completion(
 686        &mut self,
 687        project: Option<&Entity<Project>>,
 688        buffer: &Entity<Buffer>,
 689        position: language::Anchor,
 690        response: PredictEditsResponse,
 691        cx: &mut Context<Self>,
 692    ) -> Task<Result<Option<InlineCompletion>>> {
 693        use std::future::ready;
 694
 695        self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
 696            ready(Ok(response))
 697        })
 698    }
 699
 700    pub fn request_completion(
 701        &mut self,
 702        project: Option<&Entity<Project>>,
 703        buffer: &Entity<Buffer>,
 704        position: language::Anchor,
 705        can_collect_data: bool,
 706        cx: &mut Context<Self>,
 707    ) -> Task<Result<Option<InlineCompletion>>> {
 708        let workspace = self
 709            .editor
 710            .as_ref()
 711            .and_then(|editor| editor.read(cx).workspace());
 712        self.request_completion_impl(
 713            workspace,
 714            project,
 715            buffer,
 716            position,
 717            can_collect_data,
 718            cx,
 719            Self::perform_predict_edits,
 720        )
 721    }
 722
 723    fn perform_predict_edits(
 724        params: PerformPredictEditsParams,
 725    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 726        async move {
 727            let PerformPredictEditsParams {
 728                client,
 729                llm_token,
 730                app_version,
 731                body,
 732                ..
 733            } = params;
 734
 735            let http_client = client.http_client();
 736            let mut token = llm_token.acquire(&client).await?;
 737            let mut did_retry = false;
 738
 739            loop {
 740                let request_builder = http_client::Request::builder().method(Method::POST);
 741                let request_builder =
 742                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 743                        request_builder.uri(predict_edits_url)
 744                    } else {
 745                        request_builder.uri(
 746                            http_client
 747                                .build_zed_llm_url("/predict_edits/v2", &[])?
 748                                .as_ref(),
 749                        )
 750                    };
 751                let request = request_builder
 752                    .header("Content-Type", "application/json")
 753                    .header("Authorization", format!("Bearer {}", token))
 754                    .body(serde_json::to_string(&body)?.into())?;
 755
 756                let mut response = http_client.send(request).await?;
 757
 758                if let Some(minimum_required_version) = response
 759                    .headers()
 760                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 761                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 762                {
 763                    if app_version < minimum_required_version {
 764                        return Err(anyhow!(ZedUpdateRequiredError {
 765                            minimum_version: minimum_required_version
 766                        }));
 767                    }
 768                }
 769
 770                if response.status().is_success() {
 771                    let mut body = String::new();
 772                    response.body_mut().read_to_string(&mut body).await?;
 773                    return Ok(serde_json::from_str(&body)?);
 774                } else if !did_retry
 775                    && response
 776                        .headers()
 777                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 778                        .is_some()
 779                {
 780                    did_retry = true;
 781                    token = llm_token.refresh(&client).await?;
 782                } else {
 783                    let mut body = String::new();
 784                    response.body_mut().read_to_string(&mut body).await?;
 785                    return Err(anyhow!(
 786                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 787                        response.status(),
 788                        body
 789                    ));
 790                }
 791            }
 792        }
 793    }
 794
 795    #[allow(clippy::too_many_arguments)]
 796    fn process_completion_response(
 797        prediction_response: PredictEditsResponse,
 798        buffer: Entity<Buffer>,
 799        snapshot: &BufferSnapshot,
 800        editable_range: Range<usize>,
 801        cursor_offset: usize,
 802        path: Arc<Path>,
 803        input_outline: String,
 804        input_events: String,
 805        input_excerpt: String,
 806        request_sent_at: Instant,
 807        cx: &AsyncApp,
 808    ) -> Task<Result<Option<InlineCompletion>>> {
 809        let snapshot = snapshot.clone();
 810        let request_id = prediction_response.request_id;
 811        let output_excerpt = prediction_response.output_excerpt;
 812        cx.spawn(|cx| async move {
 813            let output_excerpt: Arc<str> = output_excerpt.into();
 814
 815            let edits: Arc<[(Range<Anchor>, String)]> = cx
 816                .background_spawn({
 817                    let output_excerpt = output_excerpt.clone();
 818                    let editable_range = editable_range.clone();
 819                    let snapshot = snapshot.clone();
 820                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
 821                })
 822                .await?
 823                .into();
 824
 825            let Some((edits, snapshot, edit_preview)) = buffer.read_with(&cx, {
 826                let edits = edits.clone();
 827                |buffer, cx| {
 828                    let new_snapshot = buffer.snapshot();
 829                    let edits: Arc<[(Range<Anchor>, String)]> =
 830                        interpolate(&snapshot, &new_snapshot, edits)?.into();
 831                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 832                }
 833            })?
 834            else {
 835                return anyhow::Ok(None);
 836            };
 837
 838            let edit_preview = edit_preview.await;
 839
 840            Ok(Some(InlineCompletion {
 841                id: InlineCompletionId(request_id),
 842                path,
 843                excerpt_range: editable_range,
 844                cursor_offset,
 845                edits,
 846                edit_preview,
 847                snapshot,
 848                input_outline: input_outline.into(),
 849                input_events: input_events.into(),
 850                input_excerpt: input_excerpt.into(),
 851                output_excerpt,
 852                request_sent_at,
 853                response_received_at: Instant::now(),
 854            }))
 855        })
 856    }
 857
 858    fn parse_edits(
 859        output_excerpt: Arc<str>,
 860        editable_range: Range<usize>,
 861        snapshot: &BufferSnapshot,
 862    ) -> Result<Vec<(Range<Anchor>, String)>> {
 863        let content = output_excerpt.replace(CURSOR_MARKER, "");
 864
 865        let start_markers = content
 866            .match_indices(EDITABLE_REGION_START_MARKER)
 867            .collect::<Vec<_>>();
 868        anyhow::ensure!(
 869            start_markers.len() == 1,
 870            "expected exactly one start marker, found {}",
 871            start_markers.len()
 872        );
 873
 874        let end_markers = content
 875            .match_indices(EDITABLE_REGION_END_MARKER)
 876            .collect::<Vec<_>>();
 877        anyhow::ensure!(
 878            end_markers.len() == 1,
 879            "expected exactly one end marker, found {}",
 880            end_markers.len()
 881        );
 882
 883        let sof_markers = content
 884            .match_indices(START_OF_FILE_MARKER)
 885            .collect::<Vec<_>>();
 886        anyhow::ensure!(
 887            sof_markers.len() <= 1,
 888            "expected at most one start-of-file marker, found {}",
 889            sof_markers.len()
 890        );
 891
 892        let codefence_start = start_markers[0].0;
 893        let content = &content[codefence_start..];
 894
 895        let newline_ix = content.find('\n').context("could not find newline")?;
 896        let content = &content[newline_ix + 1..];
 897
 898        let codefence_end = content
 899            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 900            .context("could not find end marker")?;
 901        let new_text = &content[..codefence_end];
 902
 903        let old_text = snapshot
 904            .text_for_range(editable_range.clone())
 905            .collect::<String>();
 906
 907        Ok(Self::compute_edits(
 908            old_text,
 909            new_text,
 910            editable_range.start,
 911            &snapshot,
 912        ))
 913    }
 914
 915    pub fn compute_edits(
 916        old_text: String,
 917        new_text: &str,
 918        offset: usize,
 919        snapshot: &BufferSnapshot,
 920    ) -> Vec<(Range<Anchor>, String)> {
 921        text_diff(&old_text, &new_text)
 922            .into_iter()
 923            .map(|(mut old_range, new_text)| {
 924                old_range.start += offset;
 925                old_range.end += offset;
 926
 927                let prefix_len = common_prefix(
 928                    snapshot.chars_for_range(old_range.clone()),
 929                    new_text.chars(),
 930                );
 931                old_range.start += prefix_len;
 932
 933                let suffix_len = common_prefix(
 934                    snapshot.reversed_chars_for_range(old_range.clone()),
 935                    new_text[prefix_len..].chars().rev(),
 936                );
 937                old_range.end = old_range.end.saturating_sub(suffix_len);
 938
 939                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 940                let range = if old_range.is_empty() {
 941                    let anchor = snapshot.anchor_after(old_range.start);
 942                    anchor..anchor
 943                } else {
 944                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 945                };
 946                (range, new_text)
 947            })
 948            .collect()
 949    }
 950
 951    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
 952        self.rated_completions.contains(&completion_id)
 953    }
 954
 955    pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
 956        self.shown_completions.push_front(completion.clone());
 957        if self.shown_completions.len() > 50 {
 958            let completion = self.shown_completions.pop_back().unwrap();
 959            self.rated_completions.remove(&completion.id);
 960        }
 961        cx.notify();
 962    }
 963
    /// Records the user's `rating` and free-form `feedback` for `completion`.
    ///
    /// Marks the completion as rated (see `is_completion_rated`) and emits an
    /// "Edit Prediction Rated" telemetry event. The event carries the completion's prompt
    /// inputs and model output. The method then asks the telemetry client to flush its queued
    /// events and notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        // Property names in this event come from the identifiers passed to the macro, so
        // renaming `rating`/`feedback` here would rename the telemetry fields.
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush right away so the rating is sent without waiting for the next batch.
        self.client.telemetry().flush_events();
        cx.notify();
    }
 984
 985    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
 986        self.shown_completions.iter()
 987    }
 988
 989    pub fn shown_completions_len(&self) -> usize {
 990        self.shown_completions.len()
 991    }
 992
 993    fn report_changes_for_buffer(
 994        &mut self,
 995        buffer: &Entity<Buffer>,
 996        cx: &mut Context<Self>,
 997    ) -> BufferSnapshot {
 998        self.register_buffer(buffer, cx);
 999
1000        let registered_buffer = self
1001            .registered_buffers
1002            .get_mut(&buffer.entity_id())
1003            .unwrap();
1004        let new_snapshot = buffer.read(cx).snapshot();
1005
1006        if new_snapshot.version != registered_buffer.snapshot.version {
1007            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1008            self.push_event(Event::BufferChange {
1009                old_snapshot,
1010                new_snapshot: new_snapshot.clone(),
1011                timestamp: Instant::now(),
1012            });
1013        }
1014
1015        new_snapshot
1016    }
1017
1018    fn load_data_collection_choices() -> DataCollectionChoice {
1019        let choice = KEY_VALUE_STORE
1020            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1021            .log_err()
1022            .flatten();
1023
1024        match choice.as_deref() {
1025            Some("true") => DataCollectionChoice::Enabled,
1026            Some("false") => DataCollectionChoice::Disabled,
1027            Some(_) => {
1028                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1029                DataCollectionChoice::NotAnswered
1030            }
1031            None => DataCollectionChoice::NotAnswered,
1032        }
1033    }
1034}
1035
1036struct PerformPredictEditsParams {
1037    pub client: Arc<Client>,
1038    pub llm_token: LlmApiToken,
1039    pub app_version: SemanticVersion,
1040    pub body: PredictEditsBody,
1041}
1042
1043#[derive(Error, Debug)]
1044#[error(
1045    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1046)]
1047pub struct ZedUpdateRequiredError {
1048    minimum_version: SemanticVersion,
1049}
1050
1051struct LicenseDetectionWatcher {
1052    is_open_source_rx: watch::Receiver<bool>,
1053    _is_open_source_task: Task<()>,
1054}
1055
1056impl LicenseDetectionWatcher {
1057    pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
1058        let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
1059
1060        // Check if worktree is a single file, if so we do not need to check for a LICENSE file
1061        let task = if worktree.abs_path().is_file() {
1062            Task::ready(())
1063        } else {
1064            let loaded_files = LICENSE_FILES_TO_CHECK
1065                .iter()
1066                .map(Path::new)
1067                .map(|file| worktree.load_file(file, cx))
1068                .collect::<ArrayVec<_, { LICENSE_FILES_TO_CHECK.len() }>>();
1069
1070            cx.background_spawn(async move {
1071                for loaded_file in loaded_files.into_iter() {
1072                    let Ok(loaded_file) = loaded_file.await else {
1073                        continue;
1074                    };
1075
1076                    let path = &loaded_file.file.path;
1077                    if is_license_eligible_for_data_collection(&loaded_file.text) {
1078                        log::info!("detected '{path:?}' as open source license");
1079                        *is_open_source_tx.borrow_mut() = true;
1080                    } else {
1081                        log::info!("didn't detect '{path:?}' as open source license");
1082                    }
1083
1084                    // stop on the first license that successfully read
1085                    return;
1086                }
1087
1088                log::debug!("didn't find a license file to check, assuming closed source");
1089            })
1090        };
1091
1092        Self {
1093            is_open_source_rx,
1094            _is_open_source_task: task,
1095        }
1096    }
1097
1098    /// Answers false until we find out it's open source
1099    pub fn is_project_open_source(&self) -> bool {
1100        *self.is_open_source_rx.borrow()
1101    }
1102}
1103
/// Returns the UTF-8 byte length of the longest common prefix of two character streams.
///
/// Passing reversed iterators measures the common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
1110
1111fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
1112    let mut input_outline = String::new();
1113
1114    writeln!(
1115        input_outline,
1116        "```{}",
1117        snapshot
1118            .file()
1119            .map_or(Cow::Borrowed("untitled"), |file| file
1120                .path()
1121                .to_string_lossy())
1122    )
1123    .unwrap();
1124
1125    if let Some(outline) = snapshot.outline(None) {
1126        for item in &outline.items {
1127            let spacing = " ".repeat(item.depth);
1128            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1129        }
1130    }
1131
1132    writeln!(input_outline, "```").unwrap();
1133
1134    input_outline
1135}
1136
1137fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1138    let mut result = String::new();
1139    for event in events.iter().rev() {
1140        let event_string = event.to_prompt();
1141        let event_tokens = tokens_for_bytes(event_string.len());
1142        if event_tokens > remaining_tokens {
1143            break;
1144        }
1145
1146        if !result.is_empty() {
1147            result.insert_str(0, "\n\n");
1148        }
1149        result.insert_str(0, &event_string);
1150        remaining_tokens -= event_tokens;
1151    }
1152    result
1153}
1154
1155struct RegisteredBuffer {
1156    snapshot: BufferSnapshot,
1157    _subscriptions: [gpui::Subscription; 2],
1158}
1159
1160#[derive(Clone)]
1161enum Event {
1162    BufferChange {
1163        old_snapshot: BufferSnapshot,
1164        new_snapshot: BufferSnapshot,
1165        timestamp: Instant,
1166    },
1167}
1168
1169impl Event {
1170    fn to_prompt(&self) -> String {
1171        match self {
1172            Event::BufferChange {
1173                old_snapshot,
1174                new_snapshot,
1175                ..
1176            } => {
1177                let mut prompt = String::new();
1178
1179                let old_path = old_snapshot
1180                    .file()
1181                    .map(|f| f.path().as_ref())
1182                    .unwrap_or(Path::new("untitled"));
1183                let new_path = new_snapshot
1184                    .file()
1185                    .map(|f| f.path().as_ref())
1186                    .unwrap_or(Path::new("untitled"));
1187                if old_path != new_path {
1188                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1189                }
1190
1191                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1192                if !diff.is_empty() {
1193                    write!(
1194                        prompt,
1195                        "User edited {:?}:\n```diff\n{}\n```",
1196                        new_path, diff
1197                    )
1198                    .unwrap();
1199                }
1200
1201                prompt
1202            }
1203        }
1204    }
1205}
1206
1207#[derive(Debug, Clone)]
1208struct CurrentInlineCompletion {
1209    buffer_id: EntityId,
1210    completion: InlineCompletion,
1211}
1212
1213impl CurrentInlineCompletion {
1214    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1215        if self.buffer_id != old_completion.buffer_id {
1216            return true;
1217        }
1218
1219        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1220            return true;
1221        };
1222        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1223            return false;
1224        };
1225
1226        if old_edits.len() == 1 && new_edits.len() == 1 {
1227            let (old_range, old_text) = &old_edits[0];
1228            let (new_range, new_text) = &new_edits[0];
1229            new_range == old_range && new_text.starts_with(old_text)
1230        } else {
1231            true
1232        }
1233    }
1234}
1235
1236struct PendingCompletion {
1237    id: usize,
1238    _task: Task<()>,
1239}
1240
/// The user's answer to whether data from edit predictions may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Whether collection is allowed; only an explicit opt-in counts.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the flipped choice. An unanswered choice flips to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` means opted in and `false` means opted out.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1280
1281pub struct ProviderDataCollection {
1282    /// When set to None, data collection is not possible in the provider buffer
1283    choice: Option<Entity<DataCollectionChoice>>,
1284    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1285}
1286
1287impl ProviderDataCollection {
1288    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1289        let choice_and_watcher = buffer.and_then(|buffer| {
1290            let file = buffer.read(cx).file()?;
1291
1292            if !file.is_local() || file.is_private() {
1293                return None;
1294            }
1295
1296            let zeta = zeta.read(cx);
1297            let choice = zeta.data_collection_choice.clone();
1298
1299            let license_detection_watcher = zeta
1300                .license_detection_watchers
1301                .get(&file.worktree_id(cx))
1302                .cloned()?;
1303
1304            Some((choice, license_detection_watcher))
1305        });
1306
1307        if let Some((choice, watcher)) = choice_and_watcher {
1308            ProviderDataCollection {
1309                choice: Some(choice),
1310                license_detection_watcher: Some(watcher),
1311            }
1312        } else {
1313            ProviderDataCollection {
1314                choice: None,
1315                license_detection_watcher: None,
1316            }
1317        }
1318    }
1319
1320    pub fn can_collect_data(&self, cx: &App) -> bool {
1321        self.is_data_collection_enabled(cx) && self.is_project_open_source()
1322    }
1323
1324    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1325        self.choice
1326            .as_ref()
1327            .is_some_and(|choice| choice.read(cx).is_enabled())
1328    }
1329
1330    fn is_project_open_source(&self) -> bool {
1331        self.license_detection_watcher
1332            .as_ref()
1333            .is_some_and(|watcher| watcher.is_project_open_source())
1334    }
1335
1336    pub fn toggle(&mut self, cx: &mut App) {
1337        if let Some(choice) = self.choice.as_mut() {
1338            let new_choice = choice.update(cx, |choice, _cx| {
1339                let new_choice = choice.toggle();
1340                *choice = new_choice;
1341                new_choice
1342            });
1343
1344            db::write_and_log(cx, move || {
1345                KEY_VALUE_STORE.write_kvp(
1346                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1347                    new_choice.is_enabled().to_string(),
1348                )
1349            });
1350        }
1351    }
1352}
1353
1354pub struct ZetaInlineCompletionProvider {
1355    zeta: Entity<Zeta>,
1356    pending_completions: ArrayVec<PendingCompletion, 2>,
1357    next_pending_completion_id: usize,
1358    current_completion: Option<CurrentInlineCompletion>,
1359    /// None if this is entirely disabled for this provider
1360    provider_data_collection: ProviderDataCollection,
1361    last_request_timestamp: Instant,
1362}
1363
1364impl ZetaInlineCompletionProvider {
1365    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1366
1367    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1368        Self {
1369            zeta,
1370            pending_completions: ArrayVec::new(),
1371            next_pending_completion_id: 0,
1372            current_completion: None,
1373            provider_data_collection,
1374            last_request_timestamp: Instant::now(),
1375        }
1376    }
1377}
1378
// Adapts the shared `Zeta` model to the editor's edit-prediction provider interface.
impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether the user opted into data collection for this buffer, together with
    /// whether its project was detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    /// Flips and persists the data-collection choice (no-op when collection is impossible).
    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    /// Always enabled, regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }

    /// True until the user has accepted the terms of service.
    fn needs_terms_acceptance(&self, cx: &App) -> bool {
        !self.zeta.read(cx).tos_accepted
    }

    /// True while at least one prediction request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Starts a new prediction request for `buffer` at `position`.
    ///
    /// Nothing is requested when:
    /// - the terms have not been accepted;
    /// - an app update is required;
    /// - the current prediction can still be re-anchored onto the buffer.
    ///
    /// Requests are throttled to one per `THROTTLE_TIMEOUT`. `_debounce` is ignored.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if !self.zeta.read(cx).tos_accepted {
            return;
        }

        if self.zeta.read(cx).update_required {
            return;
        }

        // Keep showing the current prediction while it still applies to the edited buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(|this, mut cx| async move {
            // Throttle: wait until THROTTLE_TIMEOUT has passed since the previous request started.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // `this.update` yields an error if the provider can no longer be updated
            // (presumably because it was released); that error flows into `completion` below.
            let completion_request = this.update(&mut cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(
                        project.as_ref(),
                        &buffer,
                        position,
                        can_collect_data,
                        cx,
                    )
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentInlineCompletion {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // On failure (logged by `log_err`) or when no prediction was produced, only the
            // pending-request bookkeeping is updated.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(&mut cx, |this, cx| {
                    // See the identical bookkeeping in the success path below.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(&mut cx, |this, cx| {
                // If this is the oldest pending request, retire just it. Otherwise the newer
                // request finished first, so every pending entry, including the older and
                // slower one, is dropped.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Show the new prediction if nothing is shown yet, or if
                // `should_replace_completion` prefers it over the current one.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(&old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: inline_completion::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Accepting drops pending requests; the current prediction itself is left untouched.
    fn accept(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
    }

    /// Drops pending requests and the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the group of predicted edits nearest the cursor.
    ///
    /// The current prediction is discarded if it targets another buffer or can no longer be
    /// re-anchored onto this one. Otherwise the edit whose start or end row is closest to
    /// the cursor row is chosen. Neighbouring edits are then included in both directions for
    /// as long as they lie within one row of that closest edit.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<inline_completion::InlineCompletion> {
        let CurrentInlineCompletion {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // NOTE(review): the unsigned row subtractions below assume `edits` are ordered by
        // position and non-overlapping; confirm `interpolate` guarantees this.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(inline_completion::InlineCompletion {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1636
/// Estimates how many model tokens `bytes` bytes of text occupy, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Typical number of string bytes per token for the purposes of limiting model input. This is
    /// intentionally low to err on the side of underestimating limits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
1643
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Checks that `InlineCompletion::interpolate` re-bases a completion's
    /// edits onto a buffer the user has edited since the completion was
    /// produced. It also checks that `interpolate` returns `None` once the
    /// user's edits diverge from the prediction.
    #[gpui::test]
    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let completion = InlineCompletion {
            edits,
            edit_preview,
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: InlineCompletionId(Uuid::new_v4()),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing a prefix of the predicted text shrinks the remaining insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // Deleting text the prediction did not anticipate invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Checks that a predicted editable region is diffed against the buffer.
    /// The expected result is a set of minimal insertions rather than a
    /// wholesale replacement of the region.
    #[gpui::test]
    async fn test_clean_up_diff(cx: &mut TestAppContext) {
        init_test(cx);

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
            ]
        );

        let edits = edits_for_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await;
        assert_eq!(
            edits,
            [
                (
                    Point::new(1, 26)..Point::new(1, 26),
                    " brown fox jumps over the lazy dog".to_string()
                ),
                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
            ]
        );
    }

    /// Checks that a prediction which appends a line at the end of the buffer
    /// applies cleanly. The buffer ends in a newline and the cursor is on the
    /// empty last line.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        init_test(cx);

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        let (buffer, completion) = request_completion_with_response(
            buffer_content,
            completion_response,
            Point::new(1, 0),
            cx,
        )
        .await;
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Installs a test `SettingsStore` global and calls
    /// `client::init_settings`. This is the setup the network-backed tests
    /// perform before building a `Client`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });
    }

    /// Requests a completion from a `Zeta` whose fake HTTP client always
    /// answers with `completion_response` as the `output_excerpt`.
    ///
    /// A local buffer holding `buffer_content` is created, and the request is
    /// issued with the cursor anchored before `cursor_position`. The
    /// `GetLlmToken` RPC sent to the fake server is answered with an empty
    /// token. Returns the buffer together with the resulting completion.
    ///
    /// Callers are expected to have run [`init_test`] first.
    ///
    /// # Panics
    /// Panics if the token request is never received, or if the completion
    /// task fails or yields no completion.
    async fn request_completion_with_response(
        buffer_content: &str,
        completion_response: &str,
        cursor_position: Point,
        cx: &mut TestAppContext,
    ) -> (Entity<Buffer>, InlineCompletion) {
        let completion_response = completion_response.to_string();
        let http_client = FakeHttpClient::create(move |_| {
            // Clone per call so the handler keeps ownership of the excerpt
            // and can be invoked again.
            let output_excerpt = completion_response.clone();
            async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(
                        serde_json::to_string(&PredictEditsResponse {
                            request_id: Uuid::new_v4(),
                            output_excerpt,
                        })
                        .unwrap()
                        .into(),
                    )
                    .unwrap())
            }
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(cursor_position));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(None, &buffer, cursor, false, cx)
        });

        // The request fetches an LLM token over RPC first; answer it so the
        // HTTP request can proceed.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        (buffer, completion)
    }

    /// Runs a completion request against `buffer_content` with the cursor at
    /// the start of the second line. The model response is stubbed with
    /// `completion_response`.
    ///
    /// Returns the resulting edits as point ranges. The ranges are resolved
    /// against the snapshot the completion was computed for.
    async fn edits_for_prediction(
        buffer_content: &str,
        completion_response: &str,
        cx: &mut TestAppContext,
    ) -> Vec<(Range<Point>, String)> {
        let (_buffer, completion) = request_completion_with_response(
            buffer_content,
            completion_response,
            Point::new(1, 0),
            cx,
        )
        .await;
        completion
            .edits
            .iter()
            .map(|(old_range, new_text)| {
                (old_range.to_point(&completion.snapshot), new_text.clone())
            })
            .collect()
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    /// Range starts use `anchor_after` and range ends use `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in the current contents of
    /// `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Enables `env_logger` output for tests when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}