zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod license_detection;
   4mod onboarding_banner;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::KEY_VALUE_STORE;
  11pub use init::*;
  12use inline_completion::DataCollectionState;
  13pub use license_detection::is_license_eligible_for_data_collection;
  14pub use onboarding_banner::*;
  15pub use rate_completion_modal::*;
  16
  17use anyhow::{anyhow, Context as _, Result};
  18use arrayvec::ArrayVec;
  19use client::{Client, UserStore};
  20use collections::{HashMap, HashSet, VecDeque};
  21use feature_flags::FeatureFlagAppExt as _;
  22use futures::AsyncReadExt;
  23use gpui::{
  24    actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
  25};
  26use http_client::{HttpClient, Method};
  27use language::{
  28    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, Point, ToOffset, ToPoint,
  29};
  30use language_models::LlmApiToken;
  31use postage::watch;
  32use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
  33use settings::WorktreeId;
  34use std::{
  35    borrow::Cow,
  36    cmp,
  37    fmt::Write,
  38    future::Future,
  39    mem,
  40    ops::Range,
  41    path::Path,
  42    rc::Rc,
  43    sync::Arc,
  44    time::{Duration, Instant},
  45};
  46use telemetry_events::InlineCompletionRating;
  47use util::ResultExt;
  48use uuid::Uuid;
  49use worktree::Worktree;
  50
  51const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  52const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  53const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  54const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  55const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  56const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  57
  58// TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants?
  59
  60/// Typical number of string bytes per token for the purposes of limiting model input. This is
  61/// intentionally low to err on the side of underestimating limits.
  62const BYTES_PER_TOKEN_GUESS: usize = 3;
  63
  64/// Output token limit, used to inform the size of the input. A copy of this constant is also in
  65/// `crates/collab/src/llm.rs`.
  66const MAX_OUTPUT_TOKENS: usize = 2048;
  67
  68/// Total bytes limit for editable region of buffer excerpt.
  69///
  70/// The number of output tokens is relevant to the size of the input excerpt because the model is
  71/// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens
  72/// remaining for the model to specify insertions.
  73const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;
  74
  75/// Total line limit for editable region of buffer excerpt.
  76const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64;
  77
  78/// Note that this is not the limit for the overall prompt, just for the inputs to the template
  79/// instantiated in `crates/collab/src/llm.rs`.
  80const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2;
  81
  82/// Maximum number of events to include in the prompt.
  83const MAX_EVENT_COUNT: usize = 16;
  84
  85/// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of
  86/// equally splitting up the the remaining bytes after the largest possible buffer excerpt.
  87const PER_EVENT_BYTE_LIMIT: usize =
  88    (TOTAL_BYTE_LIMIT - BUFFER_EXCERPT_BYTE_LIMIT) / MAX_EVENT_COUNT * 4;
  89
  90actions!(edit_prediction, [ClearHistory]);
  91
  92#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  93pub struct InlineCompletionId(Uuid);
  94
  95impl From<InlineCompletionId> for gpui::ElementId {
  96    fn from(value: InlineCompletionId) -> Self {
  97        gpui::ElementId::Uuid(value.0)
  98    }
  99}
 100
 101impl std::fmt::Display for InlineCompletionId {
 102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 103        write!(f, "{}", self.0)
 104    }
 105}
 106
 107impl InlineCompletionId {
 108    fn new() -> Self {
 109        Self(Uuid::new_v4())
 110    }
 111}
 112
 113#[derive(Clone)]
 114struct ZetaGlobal(Entity<Zeta>);
 115
 116impl Global for ZetaGlobal {}
 117
 118#[derive(Clone)]
 119pub struct InlineCompletion {
 120    id: InlineCompletionId,
 121    path: Arc<Path>,
 122    excerpt_range: Range<usize>,
 123    cursor_offset: usize,
 124    edits: Arc<[(Range<Anchor>, String)]>,
 125    snapshot: BufferSnapshot,
 126    edit_preview: EditPreview,
 127    input_outline: Arc<str>,
 128    input_events: Arc<str>,
 129    input_excerpt: Arc<str>,
 130    output_excerpt: Arc<str>,
 131    request_sent_at: Instant,
 132    response_received_at: Instant,
 133}
 134
 135impl InlineCompletion {
 136    fn latency(&self) -> Duration {
 137        self.response_received_at
 138            .duration_since(self.request_sent_at)
 139    }
 140
 141    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 142        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 143    }
 144}
 145
 146fn interpolate(
 147    old_snapshot: &BufferSnapshot,
 148    new_snapshot: &BufferSnapshot,
 149    current_edits: Arc<[(Range<Anchor>, String)]>,
 150) -> Option<Vec<(Range<Anchor>, String)>> {
 151    let mut edits = Vec::new();
 152
 153    let mut model_edits = current_edits.into_iter().peekable();
 154    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 155        while let Some((model_old_range, _)) = model_edits.peek() {
 156            let model_old_range = model_old_range.to_offset(old_snapshot);
 157            if model_old_range.end < user_edit.old.start {
 158                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 159                edits.push((model_old_range.clone(), model_new_text.clone()));
 160            } else {
 161                break;
 162            }
 163        }
 164
 165        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 166            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 167            if user_edit.old == model_old_offset_range {
 168                let user_new_text = new_snapshot
 169                    .text_for_range(user_edit.new.clone())
 170                    .collect::<String>();
 171
 172                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 173                    if !model_suffix.is_empty() {
 174                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 175                        edits.push((anchor..anchor, model_suffix.to_string()));
 176                    }
 177
 178                    model_edits.next();
 179                    continue;
 180                }
 181            }
 182        }
 183
 184        return None;
 185    }
 186
 187    edits.extend(model_edits.cloned());
 188
 189    if edits.is_empty() {
 190        None
 191    } else {
 192        Some(edits)
 193    }
 194}
 195
 196impl std::fmt::Debug for InlineCompletion {
 197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 198        f.debug_struct("InlineCompletion")
 199            .field("id", &self.id)
 200            .field("path", &self.path)
 201            .field("edits", &self.edits)
 202            .finish_non_exhaustive()
 203    }
 204}
 205
 206pub struct Zeta {
 207    client: Arc<Client>,
 208    events: VecDeque<Event>,
 209    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 210    shown_completions: VecDeque<InlineCompletion>,
 211    rated_completions: HashSet<InlineCompletionId>,
 212    data_collection_choice: Entity<DataCollectionChoice>,
 213    llm_token: LlmApiToken,
 214    _llm_token_subscription: Subscription,
 215    tos_accepted: bool, // Terms of service accepted
 216    _user_store_subscription: Subscription,
 217    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 218}
 219
 220impl Zeta {
 221    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 222        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 223    }
 224
 225    pub fn register(
 226        worktree: Option<Entity<Worktree>>,
 227        client: Arc<Client>,
 228        user_store: Entity<UserStore>,
 229        cx: &mut App,
 230    ) -> Entity<Self> {
 231        let this = Self::global(cx).unwrap_or_else(|| {
 232            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 233            cx.set_global(ZetaGlobal(entity.clone()));
 234            entity
 235        });
 236
 237        this.update(cx, move |this, cx| {
 238            if let Some(worktree) = worktree {
 239                worktree.update(cx, |worktree, cx| {
 240                    this.license_detection_watchers
 241                        .entry(worktree.id())
 242                        .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
 243                });
 244            }
 245        });
 246
 247        this
 248    }
 249
 250    pub fn clear_history(&mut self) {
 251        self.events.clear();
 252    }
 253
 254    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 255        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
 256
 257        let data_collection_choice = Self::load_data_collection_choices();
 258        let data_collection_choice = cx.new(|_| data_collection_choice);
 259
 260        Self {
 261            client,
 262            events: VecDeque::new(),
 263            shown_completions: VecDeque::new(),
 264            rated_completions: HashSet::default(),
 265            registered_buffers: HashMap::default(),
 266            data_collection_choice,
 267            llm_token: LlmApiToken::default(),
 268            _llm_token_subscription: cx.subscribe(
 269                &refresh_llm_token_listener,
 270                |this, _listener, _event, cx| {
 271                    let client = this.client.clone();
 272                    let llm_token = this.llm_token.clone();
 273                    cx.spawn(|_this, _cx| async move {
 274                        llm_token.refresh(&client).await?;
 275                        anyhow::Ok(())
 276                    })
 277                    .detach_and_log_err(cx);
 278                },
 279            ),
 280            tos_accepted: user_store
 281                .read(cx)
 282                .current_user_has_accepted_terms()
 283                .unwrap_or(false),
 284            _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
 285                match event {
 286                    client::user::Event::PrivateUserInfoUpdated => {
 287                        this.tos_accepted = user_store
 288                            .read(cx)
 289                            .current_user_has_accepted_terms()
 290                            .unwrap_or(false);
 291                    }
 292                    _ => {}
 293                }
 294            }),
 295            license_detection_watchers: HashMap::default(),
 296        }
 297    }
 298
 299    fn push_event(&mut self, event: Event) {
 300        if let Some(Event::BufferChange {
 301            new_snapshot: last_new_snapshot,
 302            timestamp: last_timestamp,
 303            ..
 304        }) = self.events.back_mut()
 305        {
 306            // Coalesce edits for the same buffer when they happen one after the other.
 307            let Event::BufferChange {
 308                old_snapshot,
 309                new_snapshot,
 310                timestamp,
 311            } = &event;
 312
 313            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 314                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 315                && old_snapshot.version == last_new_snapshot.version
 316            {
 317                *last_new_snapshot = new_snapshot.clone();
 318                *last_timestamp = *timestamp;
 319                return;
 320            }
 321        }
 322
 323        self.events.push_back(event);
 324        if self.events.len() >= MAX_EVENT_COUNT {
 325            self.events.drain(..MAX_EVENT_COUNT / 2);
 326        }
 327    }
 328
 329    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 330        let buffer_id = buffer.entity_id();
 331        let weak_buffer = buffer.downgrade();
 332
 333        if let std::collections::hash_map::Entry::Vacant(entry) =
 334            self.registered_buffers.entry(buffer_id)
 335        {
 336            let snapshot = buffer.read(cx).snapshot();
 337
 338            entry.insert(RegisteredBuffer {
 339                snapshot,
 340                _subscriptions: [
 341                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 342                        this.handle_buffer_event(buffer, event, cx);
 343                    }),
 344                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 345                        this.registered_buffers.remove(&weak_buffer.entity_id());
 346                    }),
 347                ],
 348            });
 349        };
 350    }
 351
 352    fn handle_buffer_event(
 353        &mut self,
 354        buffer: Entity<Buffer>,
 355        event: &language::BufferEvent,
 356        cx: &mut Context<Self>,
 357    ) {
 358        if let language::BufferEvent::Edited = event {
 359            self.report_changes_for_buffer(&buffer, cx);
 360        }
 361    }
 362
 363    pub fn request_completion_impl<F, R>(
 364        &mut self,
 365        buffer: &Entity<Buffer>,
 366        cursor: language::Anchor,
 367        data_collection_permission: bool,
 368        cx: &mut Context<Self>,
 369        perform_predict_edits: F,
 370    ) -> Task<Result<Option<InlineCompletion>>>
 371    where
 372        F: FnOnce(Arc<Client>, LlmApiToken, bool, PredictEditsParams) -> R + 'static,
 373        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 374    {
 375        let snapshot = self.report_changes_for_buffer(&buffer, cx);
 376        let cursor_point = cursor.to_point(&snapshot);
 377        let cursor_offset = cursor_point.to_offset(&snapshot);
 378        let events = self.events.clone();
 379        let path: Arc<Path> = snapshot
 380            .file()
 381            .map(|f| Arc::from(f.full_path(cx).as_path()))
 382            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 383
 384        let client = self.client.clone();
 385        let llm_token = self.llm_token.clone();
 386        let is_staff = cx.is_staff();
 387
 388        let buffer = buffer.clone();
 389        cx.spawn(|_, cx| async move {
 390            let request_sent_at = Instant::now();
 391
 392            let (input_events, input_excerpt, excerpt_range, input_outline) = cx
 393                .background_executor()
 394                .spawn({
 395                    let snapshot = snapshot.clone();
 396                    let path = path.clone();
 397                    async move {
 398                        let path = path.to_string_lossy();
 399                        let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position(
 400                            cursor_point,
 401                            BUFFER_EXCERPT_BYTE_LIMIT,
 402                            BUFFER_EXCERPT_LINE_LIMIT,
 403                            &path,
 404                            &snapshot,
 405                        )?;
 406                        let input_excerpt = prompt_for_excerpt(
 407                            cursor_offset,
 408                            &excerpt_range,
 409                            excerpt_len_guess,
 410                            &path,
 411                            &snapshot,
 412                        );
 413
 414                        let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
 415                        let input_events = prompt_for_events(events.iter(), bytes_remaining);
 416
 417                        // Note that input_outline is not currently used in prompt generation and so
 418                        // is not counted towards TOTAL_BYTE_LIMIT.
 419                        let input_outline = prompt_for_outline(&snapshot);
 420
 421                        anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
 422                    }
 423                })
 424                .await?;
 425
 426            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 427
 428            let body = PredictEditsParams {
 429                input_events: input_events.clone(),
 430                input_excerpt: input_excerpt.clone(),
 431                outline: Some(input_outline.clone()),
 432                data_collection_permission,
 433            };
 434
 435            let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
 436
 437            let output_excerpt = response.output_excerpt;
 438            log::debug!("completion response: {}", output_excerpt);
 439
 440            Self::process_completion_response(
 441                output_excerpt,
 442                buffer,
 443                &snapshot,
 444                excerpt_range,
 445                cursor_offset,
 446                path,
 447                input_outline,
 448                input_events,
 449                input_excerpt,
 450                request_sent_at,
 451                &cx,
 452            )
 453            .await
 454        })
 455    }
 456
 457    // Generates several example completions of various states to fill the Zeta completion modal
 458    #[cfg(any(test, feature = "test-support"))]
 459    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
 460        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 461            And maybe a short line
 462
 463            Then a few lines
 464
 465            and then another
 466            "#};
 467
 468        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
 469        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 470
 471        let completion_tasks = vec![
 472            self.fake_completion(
 473                &buffer,
 474                position,
 475                PredictEditsResponse {
 476                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 477a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 478[here's an edit]
 479And maybe a short line
 480Then a few lines
 481and then another
 482{EDITABLE_REGION_END_MARKER}
 483                        ", ),
 484                },
 485                cx,
 486            ),
 487            self.fake_completion(
 488                &buffer,
 489                position,
 490                PredictEditsResponse {
 491                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 492a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 493And maybe a short line
 494[and another edit]
 495Then a few lines
 496and then another
 497{EDITABLE_REGION_END_MARKER}
 498                        "#),
 499                },
 500                cx,
 501            ),
 502            self.fake_completion(
 503                &buffer,
 504                position,
 505                PredictEditsResponse {
 506                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 507a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 508And maybe a short line
 509
 510Then a few lines
 511
 512and then another
 513{EDITABLE_REGION_END_MARKER}
 514                        "#),
 515                },
 516                cx,
 517            ),
 518            self.fake_completion(
 519                &buffer,
 520                position,
 521                PredictEditsResponse {
 522                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 523a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 524And maybe a short line
 525
 526Then a few lines
 527
 528and then another
 529{EDITABLE_REGION_END_MARKER}
 530                        "#),
 531                },
 532                cx,
 533            ),
 534            self.fake_completion(
 535                &buffer,
 536                position,
 537                PredictEditsResponse {
 538                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 539a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 540And maybe a short line
 541Then a few lines
 542[a third completion]
 543and then another
 544{EDITABLE_REGION_END_MARKER}
 545                        "#),
 546                },
 547                cx,
 548            ),
 549            self.fake_completion(
 550                &buffer,
 551                position,
 552                PredictEditsResponse {
 553                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 554a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 555And maybe a short line
 556and then another
 557[fourth completion example]
 558{EDITABLE_REGION_END_MARKER}
 559                        "#),
 560                },
 561                cx,
 562            ),
 563            self.fake_completion(
 564                &buffer,
 565                position,
 566                PredictEditsResponse {
 567                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 568a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 569And maybe a short line
 570Then a few lines
 571and then another
 572[fifth and final completion]
 573{EDITABLE_REGION_END_MARKER}
 574                        "#),
 575                },
 576                cx,
 577            ),
 578        ];
 579
 580        cx.spawn(|zeta, mut cx| async move {
 581            for task in completion_tasks {
 582                task.await.unwrap();
 583            }
 584
 585            zeta.update(&mut cx, |zeta, _cx| {
 586                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
 587                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
 588            })
 589            .ok();
 590        })
 591    }
 592
 593    #[cfg(any(test, feature = "test-support"))]
 594    pub fn fake_completion(
 595        &mut self,
 596        buffer: &Entity<Buffer>,
 597        position: language::Anchor,
 598        response: PredictEditsResponse,
 599        cx: &mut Context<Self>,
 600    ) -> Task<Result<Option<InlineCompletion>>> {
 601        use std::future::ready;
 602
 603        self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| {
 604            ready(Ok(response))
 605        })
 606    }
 607
 608    pub fn request_completion(
 609        &mut self,
 610        buffer: &Entity<Buffer>,
 611        position: language::Anchor,
 612        data_collection_permission: bool,
 613        cx: &mut Context<Self>,
 614    ) -> Task<Result<Option<InlineCompletion>>> {
 615        self.request_completion_impl(
 616            buffer,
 617            position,
 618            data_collection_permission,
 619            cx,
 620            Self::perform_predict_edits,
 621        )
 622    }
 623
 624    fn perform_predict_edits(
 625        client: Arc<Client>,
 626        llm_token: LlmApiToken,
 627        _is_staff: bool,
 628        body: PredictEditsParams,
 629    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 630        async move {
 631            let http_client = client.http_client();
 632            let mut token = llm_token.acquire(&client).await?;
 633            let mut did_retry = false;
 634
 635            loop {
 636                let request_builder = http_client::Request::builder().method(Method::POST).uri(
 637                    http_client
 638                        .build_zed_llm_url("/predict_edits", &[])?
 639                        .as_ref(),
 640                );
 641                let request = request_builder
 642                    .header("Content-Type", "application/json")
 643                    .header("Authorization", format!("Bearer {}", token))
 644                    .body(serde_json::to_string(&body)?.into())?;
 645
 646                let mut response = http_client.send(request).await?;
 647
 648                if response.status().is_success() {
 649                    let mut body = String::new();
 650                    response.body_mut().read_to_string(&mut body).await?;
 651                    return Ok(serde_json::from_str(&body)?);
 652                } else if !did_retry
 653                    && response
 654                        .headers()
 655                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 656                        .is_some()
 657                {
 658                    did_retry = true;
 659                    token = llm_token.refresh(&client).await?;
 660                } else {
 661                    let mut body = String::new();
 662                    response.body_mut().read_to_string(&mut body).await?;
 663                    return Err(anyhow!(
 664                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 665                        response.status(),
 666                        body
 667                    ));
 668                }
 669            }
 670        }
 671    }
 672
 673    #[allow(clippy::too_many_arguments)]
 674    fn process_completion_response(
 675        output_excerpt: String,
 676        buffer: Entity<Buffer>,
 677        snapshot: &BufferSnapshot,
 678        excerpt_range: Range<usize>,
 679        cursor_offset: usize,
 680        path: Arc<Path>,
 681        input_outline: String,
 682        input_events: String,
 683        input_excerpt: String,
 684        request_sent_at: Instant,
 685        cx: &AsyncApp,
 686    ) -> Task<Result<Option<InlineCompletion>>> {
 687        let snapshot = snapshot.clone();
 688        cx.spawn(|cx| async move {
 689            let output_excerpt: Arc<str> = output_excerpt.into();
 690
 691            let edits: Arc<[(Range<Anchor>, String)]> = cx
 692                .background_executor()
 693                .spawn({
 694                    let output_excerpt = output_excerpt.clone();
 695                    let excerpt_range = excerpt_range.clone();
 696                    let snapshot = snapshot.clone();
 697                    async move { Self::parse_edits(output_excerpt, excerpt_range, &snapshot) }
 698                })
 699                .await?
 700                .into();
 701
 702            let Some((edits, snapshot, edit_preview)) = buffer.read_with(&cx, {
 703                let edits = edits.clone();
 704                |buffer, cx| {
 705                    let new_snapshot = buffer.snapshot();
 706                    let edits: Arc<[(Range<Anchor>, String)]> =
 707                        interpolate(&snapshot, &new_snapshot, edits)?.into();
 708                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 709                }
 710            })?
 711            else {
 712                return anyhow::Ok(None);
 713            };
 714
 715            let edit_preview = edit_preview.await;
 716
 717            Ok(Some(InlineCompletion {
 718                id: InlineCompletionId::new(),
 719                path,
 720                excerpt_range,
 721                cursor_offset,
 722                edits,
 723                edit_preview,
 724                snapshot,
 725                input_outline: input_outline.into(),
 726                input_events: input_events.into(),
 727                input_excerpt: input_excerpt.into(),
 728                output_excerpt,
 729                request_sent_at,
 730                response_received_at: Instant::now(),
 731            }))
 732        })
 733    }
 734
 735    fn parse_edits(
 736        output_excerpt: Arc<str>,
 737        excerpt_range: Range<usize>,
 738        snapshot: &BufferSnapshot,
 739    ) -> Result<Vec<(Range<Anchor>, String)>> {
 740        let content = output_excerpt.replace(CURSOR_MARKER, "");
 741
 742        let start_markers = content
 743            .match_indices(EDITABLE_REGION_START_MARKER)
 744            .collect::<Vec<_>>();
 745        anyhow::ensure!(
 746            start_markers.len() == 1,
 747            "expected exactly one start marker, found {}",
 748            start_markers.len()
 749        );
 750
 751        let end_markers = content
 752            .match_indices(EDITABLE_REGION_END_MARKER)
 753            .collect::<Vec<_>>();
 754        anyhow::ensure!(
 755            end_markers.len() == 1,
 756            "expected exactly one end marker, found {}",
 757            end_markers.len()
 758        );
 759
 760        let sof_markers = content
 761            .match_indices(START_OF_FILE_MARKER)
 762            .collect::<Vec<_>>();
 763        anyhow::ensure!(
 764            sof_markers.len() <= 1,
 765            "expected at most one start-of-file marker, found {}",
 766            sof_markers.len()
 767        );
 768
 769        let codefence_start = start_markers[0].0;
 770        let content = &content[codefence_start..];
 771
 772        let newline_ix = content.find('\n').context("could not find newline")?;
 773        let content = &content[newline_ix + 1..];
 774
 775        let codefence_end = content
 776            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 777            .context("could not find end marker")?;
 778        let new_text = &content[..codefence_end];
 779
 780        let old_text = snapshot
 781            .text_for_range(excerpt_range.clone())
 782            .collect::<String>();
 783
 784        Ok(Self::compute_edits(
 785            old_text,
 786            new_text,
 787            excerpt_range.start,
 788            &snapshot,
 789        ))
 790    }
 791
 792    pub fn compute_edits(
 793        old_text: String,
 794        new_text: &str,
 795        offset: usize,
 796        snapshot: &BufferSnapshot,
 797    ) -> Vec<(Range<Anchor>, String)> {
 798        let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
 799
 800        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
 801        let mut old_start = offset;
 802        for change in diff.iter_all_changes() {
 803            let value = change.value();
 804            match change.tag() {
 805                similar::ChangeTag::Equal => {
 806                    old_start += value.len();
 807                }
 808                similar::ChangeTag::Delete => {
 809                    let old_end = old_start + value.len();
 810                    if let Some((last_old_range, _)) = edits.last_mut() {
 811                        if last_old_range.end == old_start {
 812                            last_old_range.end = old_end;
 813                        } else {
 814                            edits.push((old_start..old_end, String::new()));
 815                        }
 816                    } else {
 817                        edits.push((old_start..old_end, String::new()));
 818                    }
 819                    old_start = old_end;
 820                }
 821                similar::ChangeTag::Insert => {
 822                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
 823                        if last_old_range.end == old_start {
 824                            last_new_text.push_str(value);
 825                        } else {
 826                            edits.push((old_start..old_start, value.into()));
 827                        }
 828                    } else {
 829                        edits.push((old_start..old_start, value.into()));
 830                    }
 831                }
 832            }
 833        }
 834
 835        edits
 836            .into_iter()
 837            .map(|(mut old_range, new_text)| {
 838                let prefix_len = common_prefix(
 839                    snapshot.chars_for_range(old_range.clone()),
 840                    new_text.chars(),
 841                );
 842                old_range.start += prefix_len;
 843                let suffix_len = common_prefix(
 844                    snapshot.reversed_chars_for_range(old_range.clone()),
 845                    new_text[prefix_len..].chars().rev(),
 846                );
 847                old_range.end = old_range.end.saturating_sub(suffix_len);
 848
 849                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 850                let range = if old_range.is_empty() {
 851                    let anchor = snapshot.anchor_after(old_range.start);
 852                    anchor..anchor
 853                } else {
 854                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 855                };
 856                (range, new_text)
 857            })
 858            .collect()
 859    }
 860
    /// Returns whether `completion_id` is in the set of rated completions.
    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
 864
 865    pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
 866        self.shown_completions.push_front(completion.clone());
 867        if self.shown_completions.len() > 50 {
 868            let completion = self.shown_completions.pop_back().unwrap();
 869            self.rated_completions.remove(&completion.id);
 870        }
 871        cx.notify();
 872    }
 873
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// The "Edit Prediction Rated" event carries the rating, the completion's recorded
    /// input events, input excerpt, input outline and output excerpt, and the `feedback`
    /// string. Telemetry events are flushed right away and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Send the rating immediately instead of waiting for the next scheduled flush.
        self.client.telemetry().flush_events();
        cx.notify();
    }
 894
    /// Iterates over the retained shown completions in storage order. Since
    /// `completion_shown` pushes to the front, the most recently shown one comes first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
        self.shown_completions.iter()
    }
 898
    /// Returns how many shown completions are currently retained.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
 902
 903    fn report_changes_for_buffer(
 904        &mut self,
 905        buffer: &Entity<Buffer>,
 906        cx: &mut Context<Self>,
 907    ) -> BufferSnapshot {
 908        self.register_buffer(buffer, cx);
 909
 910        let registered_buffer = self
 911            .registered_buffers
 912            .get_mut(&buffer.entity_id())
 913            .unwrap();
 914        let new_snapshot = buffer.read(cx).snapshot();
 915
 916        if new_snapshot.version != registered_buffer.snapshot.version {
 917            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 918            self.push_event(Event::BufferChange {
 919                old_snapshot,
 920                new_snapshot: new_snapshot.clone(),
 921                timestamp: Instant::now(),
 922            });
 923        }
 924
 925        new_snapshot
 926    }
 927
 928    fn load_data_collection_choices() -> DataCollectionChoice {
 929        let choice = KEY_VALUE_STORE
 930            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 931            .log_err()
 932            .flatten();
 933
 934        match choice.as_deref() {
 935            Some("true") => DataCollectionChoice::Enabled,
 936            Some("false") => DataCollectionChoice::Disabled,
 937            Some(_) => {
 938                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
 939                DataCollectionChoice::NotAnswered
 940            }
 941            None => DataCollectionChoice::NotAnswered,
 942        }
 943    }
 944}
 945
/// Holds the result of checking a worktree for a license that is eligible for data
/// collection.
struct LicenseDetectionWatcher {
    /// Receiving end of the watch channel carrying the latest detection result.
    is_open_source_rx: watch::Receiver<bool>,
    /// Task that performs the check; never read, only stored alongside the receiver.
    _is_open_source_task: Task<()>,
}
 950
 951impl LicenseDetectionWatcher {
 952    pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
 953        let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
 954
 955        const LICENSE_FILES_TO_CHECK: [&'static str; 2] = ["LICENSE", "LICENCE"]; // US and UK English spelling
 956
 957        // Check if worktree is a single file, if so we do not need to check for a LICENSE file
 958        let task = if worktree.abs_path().is_file() {
 959            Task::ready(())
 960        } else {
 961            let loaded_files_task = futures::future::join_all(
 962                LICENSE_FILES_TO_CHECK
 963                    .iter()
 964                    .map(|file| worktree.load_file(Path::new(file), cx)),
 965            );
 966
 967            cx.background_executor().spawn(async move {
 968                for loaded_file in loaded_files_task.await {
 969                    if let Some(content) = loaded_file.log_err() {
 970                        if is_license_eligible_for_data_collection(&content.text) {
 971                            *is_open_source_tx.borrow_mut() = true;
 972                            break;
 973                        }
 974                    }
 975                }
 976            })
 977        };
 978
 979        Self {
 980            is_open_source_rx,
 981            _is_open_source_task: task,
 982        }
 983    }
 984
 985    /// Answers false until we find out it's open source
 986    pub fn is_open_source(&self) -> bool {
 987        *self.is_open_source_rx.borrow()
 988    }
 989}
 990
 991fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 992    a.zip(b)
 993        .take_while(|(a, b)| a == b)
 994        .map(|(a, _)| a.len_utf8())
 995        .sum()
 996}
 997
 998fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
 999    let mut input_outline = String::new();
1000
1001    writeln!(
1002        input_outline,
1003        "```{}",
1004        snapshot
1005            .file()
1006            .map_or(Cow::Borrowed("untitled"), |file| file
1007                .path()
1008                .to_string_lossy())
1009    )
1010    .unwrap();
1011
1012    if let Some(outline) = snapshot.outline(None) {
1013        let guess_size = outline.items.len() * 15;
1014        input_outline.reserve(guess_size);
1015        for item in outline.items.iter() {
1016            let spacing = " ".repeat(item.depth);
1017            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1018        }
1019    }
1020
1021    writeln!(input_outline, "```").unwrap();
1022
1023    input_outline
1024}
1025
1026fn prompt_for_excerpt(
1027    offset: usize,
1028    excerpt_range: &Range<usize>,
1029    mut len_guess: usize,
1030    path: &str,
1031    snapshot: &BufferSnapshot,
1032) -> String {
1033    let point_range = excerpt_range.to_point(snapshot);
1034
1035    // Include one line of extra context before and after editable range, if those lines are non-empty.
1036    let extra_context_before_range =
1037        if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
1038            let range =
1039                (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot);
1040            len_guess += range.end - range.start;
1041            Some(range)
1042        } else {
1043            None
1044        };
1045    let extra_context_after_range = if point_range.end.row < snapshot.max_point().row
1046        && !snapshot.is_line_blank(point_range.end.row + 1)
1047    {
1048        let range = (point_range.end
1049            ..Point::new(
1050                point_range.end.row + 1,
1051                snapshot.line_len(point_range.end.row + 1),
1052            ))
1053            .to_offset(snapshot);
1054        len_guess += range.end - range.start;
1055        Some(range)
1056    } else {
1057        None
1058    };
1059
1060    let mut prompt_excerpt = String::with_capacity(len_guess);
1061    writeln!(prompt_excerpt, "```{}", path).unwrap();
1062
1063    if excerpt_range.start == 0 {
1064        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
1065    }
1066
1067    if let Some(extra_context_before_range) = extra_context_before_range {
1068        for chunk in snapshot.text_for_range(extra_context_before_range) {
1069            prompt_excerpt.push_str(chunk);
1070        }
1071    }
1072    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
1073    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
1074        prompt_excerpt.push_str(chunk);
1075    }
1076    prompt_excerpt.push_str(CURSOR_MARKER);
1077    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
1078        prompt_excerpt.push_str(chunk);
1079    }
1080    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
1081
1082    if let Some(extra_context_after_range) = extra_context_after_range {
1083        for chunk in snapshot.text_for_range(extra_context_after_range) {
1084            prompt_excerpt.push_str(chunk);
1085        }
1086    }
1087
1088    write!(prompt_excerpt, "\n```").unwrap();
1089    debug_assert!(
1090        prompt_excerpt.len() <= len_guess,
1091        "Excerpt length {} exceeds estimated length {}",
1092        prompt_excerpt.len(),
1093        len_guess
1094    );
1095    prompt_excerpt
1096}
1097
/// Chooses the range of whole lines around the cursor to send as the editable excerpt.
///
/// Starting from the cursor's row, the excerpt grows by alternately adding one row above
/// and one row below, for as long as the estimated prompt size stays within `byte_limit`
/// and the row span stays below `line_limit`. A direction stops growing once the buffer
/// edge is reached or once its next row would exceed `byte_limit`.
///
/// Returns the excerpt as a byte-offset range (from the start of the first row to the
/// end of the last row) together with the size estimate that was accumulated.
///
/// # Errors
///
/// Fails when the cursor's line plus the fixed prompt scaffolding already exceeds
/// `byte_limit`.
fn excerpt_range_for_position(
    cursor_point: Point,
    byte_limit: usize,
    line_limit: u32,
    path: &str,
    snapshot: &BufferSnapshot,
) -> Result<(Range<usize>, usize)> {
    let cursor_row = cursor_point.row;
    let last_buffer_row = snapshot.max_point().row;

    // This is an overestimate because it includes parts of prompt_for_excerpt which are
    // conditionally skipped.
    let mut len_guess = 0;
    len_guess += "```".len() + path.len() + 1;
    len_guess += START_OF_FILE_MARKER.len() + 1;
    len_guess += EDITABLE_REGION_START_MARKER.len() + 1;
    len_guess += CURSOR_MARKER.len();
    len_guess += EDITABLE_REGION_END_MARKER.len() + 1;
    len_guess += "```".len() + 1;

    // The cursor's own line, plus one byte for its newline.
    len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap();

    if len_guess > byte_limit {
        return Err(anyhow!("Current line too long to send to model."));
    }

    let mut excerpt_start_row = cursor_row;
    let mut excerpt_end_row = cursor_row;
    let mut no_more_before = cursor_row == 0;
    let mut no_more_after = cursor_row >= last_buffer_row;
    // Distance from the cursor row of the rows considered in the current iteration.
    let mut row_delta = 1;
    loop {
        if !no_more_before {
            let row = cursor_point.row - row_delta;
            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
            let mut new_len_guess = len_guess + line_len;
            // NOTE(review): the start-of-file marker is already part of the initial
            // estimate above, so it is counted a second time here — confirm this is intended.
            if row == 0 {
                new_len_guess += START_OF_FILE_MARKER.len() + 1;
            }
            if new_len_guess <= byte_limit {
                len_guess = new_len_guess;
                excerpt_start_row = row;
                if row == 0 {
                    no_more_before = true;
                }
            } else {
                no_more_before = true;
            }
        }
        // Checked after each direction so the span never grows past the limit by two rows.
        if excerpt_end_row - excerpt_start_row >= line_limit {
            break;
        }
        if !no_more_after {
            let row = cursor_point.row + row_delta;
            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
            let new_len_guess = len_guess + line_len;
            if new_len_guess <= byte_limit {
                len_guess = new_len_guess;
                excerpt_end_row = row;
                if row >= last_buffer_row {
                    no_more_after = true;
                }
            } else {
                no_more_after = true;
            }
        }
        if excerpt_end_row - excerpt_start_row >= line_limit {
            break;
        }
        if no_more_before && no_more_after {
            break;
        }
        row_delta += 1;
    }

    // The excerpt always covers whole lines: column 0 of the first row through the end
    // of the last row.
    let excerpt_start = Point::new(excerpt_start_row, 0);
    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
    Ok((
        excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot),
        len_guess,
    ))
}
1180
1181fn prompt_for_events<'a>(
1182    events: impl Iterator<Item = &'a Event>,
1183    mut bytes_remaining: usize,
1184) -> String {
1185    let mut result = String::new();
1186    for event in events {
1187        if !result.is_empty() {
1188            result.push('\n');
1189            result.push('\n');
1190        }
1191        let event_string = event.to_prompt();
1192        let len = event_string.len();
1193        if len > PER_EVENT_BYTE_LIMIT {
1194            continue;
1195        }
1196        if len > bytes_remaining {
1197            break;
1198        }
1199        bytes_remaining -= len;
1200        result.push_str(&event_string);
1201    }
1202    result
1203}
1204
/// State kept for each registered buffer.
struct RegisteredBuffer {
    /// Snapshot most recently recorded for the buffer.
    snapshot: BufferSnapshot,
    /// Subscriptions stored with the buffer's state; never read directly.
    _subscriptions: [gpui::Subscription; 2],
}
1209
/// A recorded editing event that can be rendered into the prompt.
#[derive(Clone)]
enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1218
1219impl Event {
1220    fn to_prompt(&self) -> String {
1221        match self {
1222            Event::BufferChange {
1223                old_snapshot,
1224                new_snapshot,
1225                ..
1226            } => {
1227                let mut prompt = String::new();
1228
1229                let old_path = old_snapshot
1230                    .file()
1231                    .map(|f| f.path().as_ref())
1232                    .unwrap_or(Path::new("untitled"));
1233                let new_path = new_snapshot
1234                    .file()
1235                    .map(|f| f.path().as_ref())
1236                    .unwrap_or(Path::new("untitled"));
1237                if old_path != new_path {
1238                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1239                }
1240
1241                let diff =
1242                    similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
1243                        .unified_diff()
1244                        .to_string();
1245                if !diff.is_empty() {
1246                    write!(
1247                        prompt,
1248                        "User edited {:?}:\n```diff\n{}\n```",
1249                        new_path, diff
1250                    )
1251                    .unwrap();
1252                }
1253
1254                prompt
1255            }
1256        }
1257    }
1258}
1259
/// A completion paired with the identity of the buffer it was requested for.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion belongs to.
    buffer_id: EntityId,
    /// The completion itself.
    completion: InlineCompletion,
}
1265
1266impl CurrentInlineCompletion {
1267    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1268        if self.buffer_id != old_completion.buffer_id {
1269            return true;
1270        }
1271
1272        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1273            return true;
1274        };
1275        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1276            return false;
1277        };
1278
1279        if old_edits.len() == 1 && new_edits.len() == 1 {
1280            let (old_range, old_text) = &old_edits[0];
1281            let (new_range, new_text) = &new_edits[0];
1282            new_range == old_range && new_text.starts_with(old_text)
1283        } else {
1284            true
1285        }
1286    }
1287}
1288
/// An in-flight completion request tracked by the provider.
struct PendingCompletion {
    /// Identifier assigned to the request when it was started.
    id: usize,
    /// Task running the request; never read, only stored with the entry.
    _task: Task<()>,
}
1293
/// The user's answer to whether prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for `Enabled`; an unanswered choice counts as not enabled.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has picked either `Enabled` or `Disabled`.
    pub fn is_answered(self) -> bool {
        matches!(self, Self::Enabled | Self::Disabled)
    }

    /// Returns the opposite choice. `NotAnswered` is treated like `Disabled` and
    /// therefore toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1333
/// Data-collection state attached to a completion provider.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License watcher for the worktree the provider's buffer belongs to, if any.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1339
1340impl ProviderDataCollection {
1341    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1342        let choice_and_watcher = buffer.and_then(|buffer| {
1343            let file = buffer.read(cx).file()?;
1344
1345            if !file.is_local() || file.is_private() {
1346                return None;
1347            }
1348
1349            let zeta = zeta.read(cx);
1350            let choice = zeta.data_collection_choice.clone();
1351
1352            // Unwrap safety: there should be a watcher for each worktree
1353            let license_detection_watcher = zeta
1354                .license_detection_watchers
1355                .get(&file.worktree_id(cx))
1356                .cloned()?;
1357
1358            Some((choice, license_detection_watcher))
1359        });
1360
1361        if let Some((choice, watcher)) = choice_and_watcher {
1362            ProviderDataCollection {
1363                choice: Some(choice),
1364                license_detection_watcher: Some(watcher),
1365            }
1366        } else {
1367            ProviderDataCollection {
1368                choice: None,
1369                license_detection_watcher: None,
1370            }
1371        }
1372    }
1373
1374    pub fn user_data_collection_choice(&self, cx: &App) -> bool {
1375        self.choice
1376            .as_ref()
1377            .map_or(false, |choice| choice.read(cx).is_enabled())
1378    }
1379
1380    pub fn data_collection_permission(&self, cx: &App) -> bool {
1381        self.choice
1382            .as_ref()
1383            .is_some_and(|choice| choice.read(cx).is_enabled())
1384            && self
1385                .license_detection_watcher
1386                .as_ref()
1387                .is_some_and(|watcher| watcher.is_open_source())
1388    }
1389
1390    pub fn toggle(&mut self, cx: &mut App) {
1391        if let Some(choice) = self.choice.as_mut() {
1392            let new_choice = choice.update(cx, |choice, _cx| {
1393                let new_choice = choice.toggle();
1394                *choice = new_choice;
1395                new_choice
1396            });
1397
1398            db::write_and_log(cx, move || {
1399                KEY_VALUE_STORE.write_kvp(
1400                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1401                    new_choice.is_enabled().to_string(),
1402                )
1403            });
1404        }
1405    }
1406}
1407
/// Inline completion provider backed by a `Zeta` entity.
pub struct ZetaInlineCompletionProvider {
    /// Entity that completion requests are sent through.
    zeta: Entity<Zeta>,
    /// In-flight requests; the fixed capacity caps them at two.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request that is started.
    next_pending_completion_id: usize,
    /// The completion currently held for suggestion, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection state for this provider; its fields are `None` when collection
    /// is not possible for the provider's buffer.
    provider_data_collection: ProviderDataCollection,
    /// Time of the most recent request, used as the reference point for throttling.
    last_request_timestamp: Instant,
}
1417
impl ZetaInlineCompletionProvider {
    /// Minimum spacing between consecutive completion requests.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current completion. The last-request
    /// timestamp starts at the moment of construction.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1432
1433impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
1434    fn name() -> &'static str {
1435        "zed-predict"
1436    }
1437
1438    fn display_name() -> &'static str {
1439        "Zed's Edit Predictions"
1440    }
1441
1442    fn show_completions_in_menu() -> bool {
1443        true
1444    }
1445
1446    fn show_completions_in_normal_mode() -> bool {
1447        true
1448    }
1449
1450    fn show_tab_accept_marker() -> bool {
1451        true
1452    }
1453
1454    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1455        if self
1456            .provider_data_collection
1457            .user_data_collection_choice(cx)
1458        {
1459            DataCollectionState::Enabled
1460        } else {
1461            DataCollectionState::Disabled
1462        }
1463    }
1464
1465    fn toggle_data_collection(&mut self, cx: &mut App) {
1466        self.provider_data_collection.toggle(cx);
1467    }
1468
1469    fn is_enabled(
1470        &self,
1471        _buffer: &Entity<Buffer>,
1472        _cursor_position: language::Anchor,
1473        _cx: &App,
1474    ) -> bool {
1475        true
1476    }
1477
1478    fn needs_terms_acceptance(&self, cx: &App) -> bool {
1479        !self.zeta.read(cx).tos_accepted
1480    }
1481
1482    fn is_refreshing(&self) -> bool {
1483        !self.pending_completions.is_empty()
1484    }
1485
1486    fn refresh(
1487        &mut self,
1488        buffer: Entity<Buffer>,
1489        position: language::Anchor,
1490        _debounce: bool,
1491        cx: &mut Context<Self>,
1492    ) {
1493        if !self.zeta.read(cx).tos_accepted {
1494            return;
1495        }
1496
1497        if let Some(current_completion) = self.current_completion.as_ref() {
1498            let snapshot = buffer.read(cx).snapshot();
1499            if current_completion
1500                .completion
1501                .interpolate(&snapshot)
1502                .is_some()
1503            {
1504                return;
1505            }
1506        }
1507
1508        let pending_completion_id = self.next_pending_completion_id;
1509        self.next_pending_completion_id += 1;
1510        let data_collection_permission =
1511            self.provider_data_collection.data_collection_permission(cx);
1512        let last_request_timestamp = self.last_request_timestamp;
1513
1514        let task = cx.spawn(|this, mut cx| async move {
1515            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1516                .checked_duration_since(Instant::now())
1517            {
1518                cx.background_executor().timer(timeout).await;
1519            }
1520
1521            let completion_request = this.update(&mut cx, |this, cx| {
1522                this.last_request_timestamp = Instant::now();
1523                this.zeta.update(cx, |zeta, cx| {
1524                    zeta.request_completion(&buffer, position, data_collection_permission, cx)
1525                })
1526            });
1527
1528            let completion = match completion_request {
1529                Ok(completion_request) => {
1530                    let completion_request = completion_request.await;
1531                    completion_request.map(|c| {
1532                        c.map(|completion| CurrentInlineCompletion {
1533                            buffer_id: buffer.entity_id(),
1534                            completion,
1535                        })
1536                    })
1537                }
1538                Err(error) => Err(error),
1539            };
1540            let Some(new_completion) = completion
1541                .context("edit prediction failed")
1542                .log_err()
1543                .flatten()
1544            else {
1545                this.update(&mut cx, |this, cx| {
1546                    if this.pending_completions[0].id == pending_completion_id {
1547                        this.pending_completions.remove(0);
1548                    } else {
1549                        this.pending_completions.clear();
1550                    }
1551
1552                    cx.notify();
1553                })
1554                .ok();
1555                return;
1556            };
1557
1558            this.update(&mut cx, |this, cx| {
1559                if this.pending_completions[0].id == pending_completion_id {
1560                    this.pending_completions.remove(0);
1561                } else {
1562                    this.pending_completions.clear();
1563                }
1564
1565                if let Some(old_completion) = this.current_completion.as_ref() {
1566                    let snapshot = buffer.read(cx).snapshot();
1567                    if new_completion.should_replace_completion(&old_completion, &snapshot) {
1568                        this.zeta.update(cx, |zeta, cx| {
1569                            zeta.completion_shown(&new_completion.completion, cx);
1570                        });
1571                        this.current_completion = Some(new_completion);
1572                    }
1573                } else {
1574                    this.zeta.update(cx, |zeta, cx| {
1575                        zeta.completion_shown(&new_completion.completion, cx);
1576                    });
1577                    this.current_completion = Some(new_completion);
1578                }
1579
1580                cx.notify();
1581            })
1582            .ok();
1583        });
1584
1585        // We always maintain at most two pending completions. When we already
1586        // have two, we replace the newest one.
1587        if self.pending_completions.len() <= 1 {
1588            self.pending_completions.push(PendingCompletion {
1589                id: pending_completion_id,
1590                _task: task,
1591            });
1592        } else if self.pending_completions.len() == 2 {
1593            self.pending_completions.pop();
1594            self.pending_completions.push(PendingCompletion {
1595                id: pending_completion_id,
1596                _task: task,
1597            });
1598        }
1599    }
1600
1601    fn cycle(
1602        &mut self,
1603        _buffer: Entity<Buffer>,
1604        _cursor_position: language::Anchor,
1605        _direction: inline_completion::Direction,
1606        _cx: &mut Context<Self>,
1607    ) {
1608        // Right now we don't support cycling.
1609    }
1610
1611    fn accept(&mut self, _cx: &mut Context<Self>) {
1612        self.pending_completions.clear();
1613    }
1614
1615    fn discard(&mut self, _cx: &mut Context<Self>) {
1616        self.pending_completions.clear();
1617        self.current_completion.take();
1618    }
1619
1620    fn suggest(
1621        &mut self,
1622        buffer: &Entity<Buffer>,
1623        cursor_position: language::Anchor,
1624        cx: &mut Context<Self>,
1625    ) -> Option<inline_completion::InlineCompletion> {
1626        let CurrentInlineCompletion {
1627            buffer_id,
1628            completion,
1629            ..
1630        } = self.current_completion.as_mut()?;
1631
1632        // Invalidate previous completion if it was generated for a different buffer.
1633        if *buffer_id != buffer.entity_id() {
1634            self.current_completion.take();
1635            return None;
1636        }
1637
1638        let buffer = buffer.read(cx);
1639        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1640            self.current_completion.take();
1641            return None;
1642        };
1643
1644        let cursor_row = cursor_position.to_point(buffer).row;
1645        let (closest_edit_ix, (closest_edit_range, _)) =
1646            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1647                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1648                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1649                cmp::min(distance_from_start, distance_from_end)
1650            })?;
1651
1652        let mut edit_start_ix = closest_edit_ix;
1653        for (range, _) in edits[..edit_start_ix].iter().rev() {
1654            let distance_from_closest_edit =
1655                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1656            if distance_from_closest_edit <= 1 {
1657                edit_start_ix -= 1;
1658            } else {
1659                break;
1660            }
1661        }
1662
1663        let mut edit_end_ix = closest_edit_ix + 1;
1664        for (range, _) in &edits[edit_end_ix..] {
1665            let distance_from_closest_edit =
1666                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1667            if distance_from_closest_edit <= 1 {
1668                edit_end_ix += 1;
1669            } else {
1670                break;
1671            }
1672        }
1673
1674        Some(inline_completion::InlineCompletion {
1675            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1676            edit_preview: Some(completion.edit_preview.clone()),
1677        })
1678    }
1679}
1680
1681#[cfg(test)]
1682mod tests {
1683    use client::test::FakeServer;
1684    use clock::FakeSystemClock;
1685    use gpui::TestAppContext;
1686    use http_client::FakeHttpClient;
1687    use indoc::indoc;
1688    use language_models::RefreshLlmTokenListener;
1689    use rpc::proto;
1690    use settings::SettingsStore;
1691
1692    use super::*;
1693
1694    #[gpui::test]
1695    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
1696        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1697        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1698            to_completion_edits(
1699                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1700                &buffer,
1701                cx,
1702            )
1703            .into()
1704        });
1705
1706        let edit_preview = cx
1707            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1708            .await;
1709
1710        let completion = InlineCompletion {
1711            edits,
1712            edit_preview,
1713            path: Path::new("").into(),
1714            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1715            id: InlineCompletionId::new(),
1716            excerpt_range: 0..0,
1717            cursor_offset: 0,
1718            input_outline: "".into(),
1719            input_events: "".into(),
1720            input_excerpt: "".into(),
1721            output_excerpt: "".into(),
1722            request_sent_at: Instant::now(),
1723            response_received_at: Instant::now(),
1724        };
1725
1726        cx.update(|cx| {
1727            assert_eq!(
1728                from_completion_edits(
1729                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1730                    &buffer,
1731                    cx
1732                ),
1733                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1734            );
1735
1736            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1737            assert_eq!(
1738                from_completion_edits(
1739                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1740                    &buffer,
1741                    cx
1742                ),
1743                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1744            );
1745
1746            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1747            assert_eq!(
1748                from_completion_edits(
1749                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1750                    &buffer,
1751                    cx
1752                ),
1753                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1754            );
1755
1756            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1757            assert_eq!(
1758                from_completion_edits(
1759                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1760                    &buffer,
1761                    cx
1762                ),
1763                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1764            );
1765
1766            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1767            assert_eq!(
1768                from_completion_edits(
1769                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1770                    &buffer,
1771                    cx
1772                ),
1773                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1774            );
1775
1776            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1777            assert_eq!(
1778                from_completion_edits(
1779                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1780                    &buffer,
1781                    cx
1782                ),
1783                vec![(9..11, "".to_string())]
1784            );
1785
1786            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1787            assert_eq!(
1788                from_completion_edits(
1789                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1790                    &buffer,
1791                    cx
1792                ),
1793                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1794            );
1795
1796            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1797            assert_eq!(
1798                from_completion_edits(
1799                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1800                    &buffer,
1801                    cx
1802                ),
1803                vec![(4..4, "M".to_string())]
1804            );
1805
1806            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1807            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1808        })
1809    }
1810
1811    #[gpui::test]
1812    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
1813        cx.update(|cx| {
1814            let settings_store = SettingsStore::test(cx);
1815            cx.set_global(settings_store);
1816            client::init_settings(cx);
1817        });
1818
1819        let buffer_content = "lorem\n";
1820        let completion_response = indoc! {"
1821            ```animals.js
1822            <|start_of_file|>
1823            <|editable_region_start|>
1824            lorem
1825            ipsum
1826            <|editable_region_end|>
1827            ```"};
1828
1829        let http_client = FakeHttpClient::create(move |_| async move {
1830            Ok(http_client::Response::builder()
1831                .status(200)
1832                .body(
1833                    serde_json::to_string(&PredictEditsResponse {
1834                        output_excerpt: completion_response.to_string(),
1835                    })
1836                    .unwrap()
1837                    .into(),
1838                )
1839                .unwrap())
1840        });
1841
1842        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1843        cx.update(|cx| {
1844            RefreshLlmTokenListener::register(client.clone(), cx);
1845        });
1846        let server = FakeServer::for_client(42, &client, cx).await;
1847        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1848        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
1849
1850        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1851        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1852        let completion_task = zeta.update(cx, |zeta, cx| {
1853            zeta.request_completion(&buffer, cursor, false, cx)
1854        });
1855
1856        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
1857        server.respond(
1858            token_request.receipt(),
1859            proto::GetLlmTokenResponse { token: "".into() },
1860        );
1861
1862        let completion = completion_task.await.unwrap().unwrap();
1863        buffer.update(cx, |buffer, cx| {
1864            buffer.edit(completion.edits.iter().cloned(), None, cx)
1865        });
1866        assert_eq!(
1867            buffer.read_with(cx, |buffer, _| buffer.text()),
1868            "lorem\nipsum"
1869        );
1870    }
1871
1872    fn to_completion_edits(
1873        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1874        buffer: &Entity<Buffer>,
1875        cx: &App,
1876    ) -> Vec<(Range<Anchor>, String)> {
1877        let buffer = buffer.read(cx);
1878        iterator
1879            .into_iter()
1880            .map(|(range, text)| {
1881                (
1882                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1883                    text,
1884                )
1885            })
1886            .collect()
1887    }
1888
1889    fn from_completion_edits(
1890        editor_edits: &[(Range<Anchor>, String)],
1891        buffer: &Entity<Buffer>,
1892        cx: &App,
1893    ) -> Vec<(Range<usize>, String)> {
1894        let buffer = buffer.read(cx);
1895        editor_edits
1896            .iter()
1897            .map(|(range, text)| {
1898                (
1899                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
1900                    text.clone(),
1901                )
1902            })
1903            .collect()
1904    }
1905
1906    #[ctor::ctor]
1907    fn init_logger() {
1908        if std::env::var("RUST_LOG").is_ok() {
1909            env_logger::init();
1910        }
1911    }
1912}