zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod license_detection;
   4mod onboarding_banner;
   5mod onboarding_modal;
   6mod rate_completion_modal;
   7
   8pub(crate) use completion_diff_element::*;
   9use db::kvp::KEY_VALUE_STORE;
  10pub use init::*;
  11use inline_completion::DataCollectionState;
  12pub use license_detection::is_license_eligible_for_data_collection;
  13pub use onboarding_banner::*;
  14pub use rate_completion_modal::*;
  15
  16use anyhow::{anyhow, Context as _, Result};
  17use arrayvec::ArrayVec;
  18use client::{Client, UserStore};
  19use collections::{HashMap, HashSet, VecDeque};
  20use feature_flags::FeatureFlagAppExt as _;
  21use futures::AsyncReadExt;
  22use gpui::{
  23    actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
  24};
  25use http_client::{HttpClient, Method};
  26use language::{
  27    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview,
  28    OffsetRangeExt, Point, ToOffset, ToPoint,
  29};
  30use language_models::LlmApiToken;
  31use postage::watch;
  32use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
  33use settings::WorktreeId;
  34use std::{
  35    borrow::Cow,
  36    cmp,
  37    fmt::Write,
  38    future::Future,
  39    mem,
  40    ops::Range,
  41    path::Path,
  42    rc::Rc,
  43    sync::Arc,
  44    time::{Duration, Instant},
  45};
  46use telemetry_events::InlineCompletionRating;
  47use util::ResultExt;
  48use uuid::Uuid;
  49use worktree::Worktree;
  50
  51const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  52const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  53const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  54const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  55const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  56const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  57
// TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants?

/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low to err on the side of underestimating limits.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Output token limit, used to inform the size of the input. A copy of this constant is also in
/// `crates/collab/src/llm.rs`.
const MAX_OUTPUT_TOKENS: usize = 2048;

/// Total bytes limit for editable region of buffer excerpt.
///
/// The number of output tokens is relevant to the size of the input excerpt because the model is
/// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens
/// remaining for the model to specify insertions.
const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;

/// Total line limit for editable region of buffer excerpt.
const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64;

/// Note that this is not the limit for the overall prompt, just for the inputs to the template
/// instantiated in `crates/collab/src/llm.rs`.
const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2;

/// Maximum number of events to include in the prompt.
const MAX_EVENT_COUNT: usize = 16;

/// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size
/// of equally splitting up the remaining bytes after the largest possible buffer excerpt.
const PER_EVENT_BYTE_LIMIT: usize =
    (TOTAL_BYTE_LIMIT - BUFFER_EXCERPT_BYTE_LIMIT) / MAX_EVENT_COUNT * 4;
  89
// Declares the `ClearHistory` action under the `edit_prediction` namespace.
// NOTE(review): presumably its handler ends up calling `Zeta::clear_history`; the handler
// registration is not visible in this part of the file.
actions!(edit_prediction, [ClearHistory]);
  91
/// Identifier of an [`InlineCompletion`], backed by a [`Uuid`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineCompletionId(Uuid);
  94
  95impl From<InlineCompletionId> for gpui::ElementId {
  96    fn from(value: InlineCompletionId) -> Self {
  97        gpui::ElementId::Uuid(value.0)
  98    }
  99}
 100
 101impl std::fmt::Display for InlineCompletionId {
 102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 103        write!(f, "{}", self.0)
 104    }
 105}
 106
 107impl InlineCompletionId {
 108    fn new() -> Self {
 109        Self(Uuid::new_v4())
 110    }
 111}
 112
/// Wrapper that stores the shared [`Zeta`] entity as a gpui global, so that it can be looked up
/// with `try_global` and installed with `set_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 117
/// A predicted set of edits for a buffer, together with the inputs and outputs of the request
/// that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Unique identifier, generated when the completion is built.
    id: InlineCompletionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the excerpt that was sent to the model.
    excerpt_range: Range<usize>,
    /// Offset of the cursor at the time of the request.
    cursor_offset: usize,
    /// Predicted edits as `(range to replace, replacement text)` pairs.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` were interpolated against.
    snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    edit_preview: EditPreview,
    /// Outline text that was sent with the request.
    input_outline: Arc<str>,
    /// Events text that was sent with the request.
    input_events: Arc<str>,
    /// Excerpt text that was sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// Captured right before the prompt inputs are assembled.
    request_sent_at: Instant,
    /// Captured when the completion value is constructed, i.e. after the response has been
    /// parsed and the edit preview has been computed.
    response_received_at: Instant,
}
 134
 135impl InlineCompletion {
 136    fn latency(&self) -> Duration {
 137        self.response_received_at
 138            .duration_since(self.request_sent_at)
 139    }
 140
 141    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 142        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 143    }
 144}
 145
 146fn interpolate(
 147    old_snapshot: &BufferSnapshot,
 148    new_snapshot: &BufferSnapshot,
 149    current_edits: Arc<[(Range<Anchor>, String)]>,
 150) -> Option<Vec<(Range<Anchor>, String)>> {
 151    let mut edits = Vec::new();
 152
 153    let mut model_edits = current_edits.into_iter().peekable();
 154    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 155        while let Some((model_old_range, _)) = model_edits.peek() {
 156            let model_old_range = model_old_range.to_offset(old_snapshot);
 157            if model_old_range.end < user_edit.old.start {
 158                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 159                edits.push((model_old_range.clone(), model_new_text.clone()));
 160            } else {
 161                break;
 162            }
 163        }
 164
 165        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 166            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 167            if user_edit.old == model_old_offset_range {
 168                let user_new_text = new_snapshot
 169                    .text_for_range(user_edit.new.clone())
 170                    .collect::<String>();
 171
 172                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 173                    if !model_suffix.is_empty() {
 174                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 175                        edits.push((anchor..anchor, model_suffix.to_string()));
 176                    }
 177
 178                    model_edits.next();
 179                    continue;
 180                }
 181            }
 182        }
 183
 184        return None;
 185    }
 186
 187    edits.extend(model_edits.cloned());
 188
 189    if edits.is_empty() {
 190        None
 191    } else {
 192        Some(edits)
 193    }
 194}
 195
 196impl std::fmt::Debug for InlineCompletion {
 197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 198        f.debug_struct("InlineCompletion")
 199            .field("id", &self.id)
 200            .field("path", &self.path)
 201            .field("edits", &self.edits)
 202            .finish_non_exhaustive()
 203    }
 204}
 205
/// State of the edit prediction service: recent edit history, registered buffers and the
/// credentials used to request predictions.
pub struct Zeta {
    /// Client used for HTTP requests and for acquiring or refreshing the LLM token.
    client: Arc<Client>,
    /// Recent events; a copy is used to build the events input of every request.
    events: VecDeque<Event>,
    /// Per-buffer state (snapshot and subscriptions), keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    // NOTE(review): presumably the completions that were displayed to the user; where entries
    // are added is not visible in this part of the file.
    shown_completions: VecDeque<InlineCompletion>,
    // NOTE(review): presumably ids of completions the user has already rated — confirm.
    rated_completions: HashSet<InlineCompletionId>,
    /// Entity holding the value returned by `load_data_collection_choices`.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token used to authorize prediction requests.
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the refresh listener emits an event.
    _llm_token_subscription: Subscription,
    tos_accepted: bool, // Terms of service accepted
    /// Keeps `tos_accepted` up to date when private user info changes.
    _user_store_subscription: Subscription,
    /// One license detection watcher per worktree passed to `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
 219
 220impl Zeta {
 221    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 222        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 223    }
 224
 225    pub fn register(
 226        worktree: Option<Entity<Worktree>>,
 227        client: Arc<Client>,
 228        user_store: Entity<UserStore>,
 229        cx: &mut App,
 230    ) -> Entity<Self> {
 231        let this = Self::global(cx).unwrap_or_else(|| {
 232            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 233            cx.set_global(ZetaGlobal(entity.clone()));
 234            entity
 235        });
 236
 237        this.update(cx, move |this, cx| {
 238            if let Some(worktree) = worktree {
 239                worktree.update(cx, |worktree, cx| {
 240                    this.license_detection_watchers
 241                        .entry(worktree.id())
 242                        .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
 243                });
 244            }
 245        });
 246
 247        this
 248    }
 249
    /// Discards all recorded events, so that later requests start from an empty history.
    pub fn clear_history(&mut self) {
        self.events.clear();
    }
 253
 254    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 255        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
 256
 257        let data_collection_choice = Self::load_data_collection_choices();
 258        let data_collection_choice = cx.new(|_| data_collection_choice);
 259
 260        Self {
 261            client,
 262            events: VecDeque::new(),
 263            shown_completions: VecDeque::new(),
 264            rated_completions: HashSet::default(),
 265            registered_buffers: HashMap::default(),
 266            data_collection_choice,
 267            llm_token: LlmApiToken::default(),
 268            _llm_token_subscription: cx.subscribe(
 269                &refresh_llm_token_listener,
 270                |this, _listener, _event, cx| {
 271                    let client = this.client.clone();
 272                    let llm_token = this.llm_token.clone();
 273                    cx.spawn(|_this, _cx| async move {
 274                        llm_token.refresh(&client).await?;
 275                        anyhow::Ok(())
 276                    })
 277                    .detach_and_log_err(cx);
 278                },
 279            ),
 280            tos_accepted: user_store
 281                .read(cx)
 282                .current_user_has_accepted_terms()
 283                .unwrap_or(false),
 284            _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
 285                match event {
 286                    client::user::Event::PrivateUserInfoUpdated => {
 287                        this.tos_accepted = user_store
 288                            .read(cx)
 289                            .current_user_has_accepted_terms()
 290                            .unwrap_or(false);
 291                    }
 292                    _ => {}
 293                }
 294            }),
 295            license_detection_watchers: HashMap::default(),
 296        }
 297    }
 298
 299    fn push_event(&mut self, event: Event) {
 300        if let Some(Event::BufferChange {
 301            new_snapshot: last_new_snapshot,
 302            timestamp: last_timestamp,
 303            ..
 304        }) = self.events.back_mut()
 305        {
 306            // Coalesce edits for the same buffer when they happen one after the other.
 307            let Event::BufferChange {
 308                old_snapshot,
 309                new_snapshot,
 310                timestamp,
 311            } = &event;
 312
 313            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 314                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 315                && old_snapshot.version == last_new_snapshot.version
 316            {
 317                *last_new_snapshot = new_snapshot.clone();
 318                *last_timestamp = *timestamp;
 319                return;
 320            }
 321        }
 322
 323        self.events.push_back(event);
 324        if self.events.len() >= MAX_EVENT_COUNT {
 325            self.events.drain(..MAX_EVENT_COUNT / 2);
 326        }
 327    }
 328
 329    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 330        let buffer_id = buffer.entity_id();
 331        let weak_buffer = buffer.downgrade();
 332
 333        if let std::collections::hash_map::Entry::Vacant(entry) =
 334            self.registered_buffers.entry(buffer_id)
 335        {
 336            let snapshot = buffer.read(cx).snapshot();
 337
 338            entry.insert(RegisteredBuffer {
 339                snapshot,
 340                _subscriptions: [
 341                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 342                        this.handle_buffer_event(buffer, event, cx);
 343                    }),
 344                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 345                        this.registered_buffers.remove(&weak_buffer.entity_id());
 346                    }),
 347                ],
 348            });
 349        };
 350    }
 351
 352    fn handle_buffer_event(
 353        &mut self,
 354        buffer: Entity<Buffer>,
 355        event: &language::BufferEvent,
 356        cx: &mut Context<Self>,
 357    ) {
 358        if let language::BufferEvent::Edited = event {
 359            self.report_changes_for_buffer(&buffer, cx);
 360        }
 361    }
 362
    /// Requests an edit prediction for `buffer` at `cursor`.
    ///
    /// The prompt inputs (events, excerpt, outline) are assembled on the background executor,
    /// sent through `perform_predict_edits`, and the response is turned into an
    /// [`InlineCompletion`] by `process_completion_response`.
    ///
    /// `perform_predict_edits` receives the client, the LLM token, the staff flag and the request
    /// body, and resolves to the model response; passing it in lets callers substitute the
    /// network request.
    ///
    /// The returned task resolves to `Ok(None)` when the response yields no applicable edits,
    /// and to an error when building the excerpt, performing the request or parsing the response
    /// fails.
    pub fn request_completion_impl<F, R>(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor: language::Anchor,
        data_collection_permission: bool,
        cx: &mut Context<Self>,
        perform_predict_edits: F,
    ) -> Task<Result<Option<InlineCompletion>>>
    where
        F: FnOnce(Arc<Client>, LlmApiToken, bool, PredictEditsParams) -> R + 'static,
        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
    {
        // The snapshot returned here is the one the whole request is computed against.
        // NOTE(review): presumably this also records pending buffer changes as events — confirm
        // in `report_changes_for_buffer`.
        let snapshot = self.report_changes_for_buffer(&buffer, cx);
        let cursor_point = cursor.to_point(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let events = self.events.clone();
        // Buffers without a file are reported under the placeholder path "untitled".
        let path: Arc<Path> = snapshot
            .file()
            .map(|f| Arc::from(f.full_path(cx).as_path()))
            .unwrap_or_else(|| Arc::from(Path::new("untitled")));

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let is_staff = cx.is_staff();

        let buffer = buffer.clone();
        cx.spawn(|_, cx| async move {
            // Taken before the prompt is built, so latency includes prompt construction.
            let request_sent_at = Instant::now();

            let (input_events, input_excerpt, excerpt_range, input_outline) = cx
                .background_executor()
                .spawn({
                    let snapshot = snapshot.clone();
                    let path = path.clone();
                    async move {
                        let path = path.to_string_lossy();
                        let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position(
                            cursor_point,
                            BUFFER_EXCERPT_BYTE_LIMIT,
                            BUFFER_EXCERPT_LINE_LIMIT,
                            &path,
                            &snapshot,
                        )?;
                        let input_excerpt = prompt_for_excerpt(
                            cursor_offset,
                            &excerpt_range,
                            excerpt_len_guess,
                            &path,
                            &snapshot,
                        );

                        // Events only get the bytes that the excerpt left unused.
                        let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
                        let input_events = prompt_for_events(events.iter(), bytes_remaining);

                        // Note that input_outline is not currently used in prompt generation and so
                        // is not counted towards TOTAL_BYTE_LIMIT.
                        let input_outline = prompt_for_outline(&snapshot);

                        anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
                    }
                })
                .await?;

            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);

            let body = PredictEditsParams {
                input_events: input_events.clone(),
                input_excerpt: input_excerpt.clone(),
                outline: Some(input_outline.clone()),
                data_collection_permission,
            };

            let response = perform_predict_edits(client, llm_token, is_staff, body).await?;

            let output_excerpt = response.output_excerpt;
            log::debug!("completion response: {}", output_excerpt);

            Self::process_completion_response(
                output_excerpt,
                buffer,
                &snapshot,
                excerpt_range,
                cursor_offset,
                path,
                input_outline,
                input_events,
                input_excerpt,
                request_sent_at,
                &cx,
            )
            .await
        })
    }
 456
    /// Generates several example completions of various states to fill the Zeta completion modal.
    ///
    /// A local test buffer is created and seven canned model responses are run through
    /// [`Self::fake_completion`] for the same cursor position. The returned task awaits all of
    /// them (panicking if any fails) and then empties the edits of two of the shown completions.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Every fake request uses the start of the second line as the cursor position.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            // Inserts a bracketed line after the first line and drops the blank lines.
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse {
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
                },
                cx,
            ),
            // Inserts a bracketed line after the second line.
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse {
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            // This response and the next one repeat the buffer's lines without any bracketed
            // addition.
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse {
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse {
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            // Inserts a bracketed line before the last line.
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse {
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            // Omits the "Then a few lines" line and appends a bracketed line at the end.
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse {
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            // Appends a bracketed line at the end.
            self.fake_completion(
                &buffer,
                position,
                PredictEditsResponse {
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        cx.spawn(|zeta, mut cx| async move {
            // Tasks are awaited in order; a failed request panics here.
            for task in completion_tasks {
                task.await.unwrap();
            }

            // Blank out the edits of the third and fourth shown completions.
            // NOTE(review): this assumes `shown_completions` holds at least four entries by now;
            // where they get recorded is not visible in this part of the file, and `unwrap`
            // panics otherwise.
            zeta.update(&mut cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 592
 593    #[cfg(any(test, feature = "test-support"))]
 594    pub fn fake_completion(
 595        &mut self,
 596        buffer: &Entity<Buffer>,
 597        position: language::Anchor,
 598        response: PredictEditsResponse,
 599        cx: &mut Context<Self>,
 600    ) -> Task<Result<Option<InlineCompletion>>> {
 601        use std::future::ready;
 602
 603        self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| {
 604            ready(Ok(response))
 605        })
 606    }
 607
    /// Requests an edit prediction for `buffer` at `position`.
    ///
    /// This is [`Self::request_completion_impl`] with the request performed by
    /// `Self::perform_predict_edits`. `data_collection_permission` is forwarded in the request
    /// body.
    pub fn request_completion(
        &mut self,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        data_collection_permission: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<InlineCompletion>>> {
        self.request_completion_impl(
            buffer,
            position,
            data_collection_permission,
            cx,
            Self::perform_predict_edits,
        )
    }
 623
 624    fn perform_predict_edits(
 625        client: Arc<Client>,
 626        llm_token: LlmApiToken,
 627        _is_staff: bool,
 628        body: PredictEditsParams,
 629    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 630        async move {
 631            let http_client = client.http_client();
 632            let mut token = llm_token.acquire(&client).await?;
 633            let mut did_retry = false;
 634
 635            loop {
 636                let request_builder = http_client::Request::builder().method(Method::POST).uri(
 637                    http_client
 638                        .build_zed_llm_url("/predict_edits", &[])?
 639                        .as_ref(),
 640                );
 641                let request = request_builder
 642                    .header("Content-Type", "application/json")
 643                    .header("Authorization", format!("Bearer {}", token))
 644                    .body(serde_json::to_string(&body)?.into())?;
 645
 646                let mut response = http_client.send(request).await?;
 647
 648                if response.status().is_success() {
 649                    let mut body = String::new();
 650                    response.body_mut().read_to_string(&mut body).await?;
 651                    return Ok(serde_json::from_str(&body)?);
 652                } else if !did_retry
 653                    && response
 654                        .headers()
 655                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 656                        .is_some()
 657                {
 658                    did_retry = true;
 659                    token = llm_token.refresh(&client).await?;
 660                } else {
 661                    let mut body = String::new();
 662                    response.body_mut().read_to_string(&mut body).await?;
 663                    return Err(anyhow!(
 664                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 665                        response.status(),
 666                        body
 667                    ));
 668                }
 669            }
 670        }
 671    }
 672
    /// Turns a raw model response into an [`InlineCompletion`].
    ///
    /// The response is parsed into edits on the background executor, and the edits are then
    /// interpolated against the buffer's current snapshot. The task resolves to `Ok(None)` when
    /// interpolation yields nothing, and to an error when parsing fails or the buffer can no
    /// longer be read.
    ///
    /// `snapshot` is the snapshot the request was computed against, and `excerpt_range` is the
    /// offset range of the excerpt within it.
    #[allow(clippy::too_many_arguments)]
    fn process_completion_response(
        output_excerpt: String,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        excerpt_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        request_sent_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<InlineCompletion>>> {
        let snapshot = snapshot.clone();
        cx.spawn(|cx| async move {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing diffs the excerpt against the response, so it runs off the main thread.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_executor()
                .spawn({
                    let output_excerpt = output_excerpt.clone();
                    let excerpt_range = excerpt_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, excerpt_range, &snapshot) }
                })
                .await?
                .into();

            // The buffer may have changed while the request was in flight, so the edits are
            // re-expressed against its current snapshot. The `?` inside the closure makes it
            // return `None` when interpolation fails.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(&cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(InlineCompletion {
                id: InlineCompletionId::new(),
                path,
                excerpt_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                request_sent_at,
                // Taken after parsing and preview computation, not when the HTTP response
                // arrived.
                response_received_at: Instant::now(),
            }))
        })
    }
 734
 735    fn parse_edits(
 736        output_excerpt: Arc<str>,
 737        excerpt_range: Range<usize>,
 738        snapshot: &BufferSnapshot,
 739    ) -> Result<Vec<(Range<Anchor>, String)>> {
 740        let content = output_excerpt.replace(CURSOR_MARKER, "");
 741
 742        let start_markers = content
 743            .match_indices(EDITABLE_REGION_START_MARKER)
 744            .collect::<Vec<_>>();
 745        anyhow::ensure!(
 746            start_markers.len() == 1,
 747            "expected exactly one start marker, found {}",
 748            start_markers.len()
 749        );
 750
 751        let end_markers = content
 752            .match_indices(EDITABLE_REGION_END_MARKER)
 753            .collect::<Vec<_>>();
 754        anyhow::ensure!(
 755            end_markers.len() == 1,
 756            "expected exactly one end marker, found {}",
 757            end_markers.len()
 758        );
 759
 760        let sof_markers = content
 761            .match_indices(START_OF_FILE_MARKER)
 762            .collect::<Vec<_>>();
 763        anyhow::ensure!(
 764            sof_markers.len() <= 1,
 765            "expected at most one start-of-file marker, found {}",
 766            sof_markers.len()
 767        );
 768
 769        let codefence_start = start_markers[0].0;
 770        let content = &content[codefence_start..];
 771
 772        let newline_ix = content.find('\n').context("could not find newline")?;
 773        let content = &content[newline_ix + 1..];
 774
 775        let codefence_end = content
 776            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 777            .context("could not find end marker")?;
 778        let new_text = &content[..codefence_end];
 779
 780        let old_text = snapshot
 781            .text_for_range(excerpt_range.clone())
 782            .collect::<String>();
 783
 784        Ok(Self::compute_edits(
 785            old_text,
 786            new_text,
 787            excerpt_range.start,
 788            &snapshot,
 789        ))
 790    }
 791
 792    pub fn compute_edits(
 793        old_text: String,
 794        new_text: &str,
 795        offset: usize,
 796        snapshot: &BufferSnapshot,
 797    ) -> Vec<(Range<Anchor>, String)> {
 798        let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
 799
 800        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
 801        let mut old_start = offset;
 802        for change in diff.iter_all_changes() {
 803            let value = change.value();
 804            match change.tag() {
 805                similar::ChangeTag::Equal => {
 806                    old_start += value.len();
 807                }
 808                similar::ChangeTag::Delete => {
 809                    let old_end = old_start + value.len();
 810                    if let Some((last_old_range, _)) = edits.last_mut() {
 811                        if last_old_range.end == old_start {
 812                            last_old_range.end = old_end;
 813                        } else {
 814                            edits.push((old_start..old_end, String::new()));
 815                        }
 816                    } else {
 817                        edits.push((old_start..old_end, String::new()));
 818                    }
 819                    old_start = old_end;
 820                }
 821                similar::ChangeTag::Insert => {
 822                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
 823                        if last_old_range.end == old_start {
 824                            last_new_text.push_str(value);
 825                        } else {
 826                            edits.push((old_start..old_start, value.into()));
 827                        }
 828                    } else {
 829                        edits.push((old_start..old_start, value.into()));
 830                    }
 831                }
 832            }
 833        }
 834
 835        edits
 836            .into_iter()
 837            .map(|(mut old_range, new_text)| {
 838                let prefix_len = common_prefix(
 839                    snapshot.chars_for_range(old_range.clone()),
 840                    new_text.chars(),
 841                );
 842                old_range.start += prefix_len;
 843                let suffix_len = common_prefix(
 844                    snapshot.reversed_chars_for_range(old_range.clone()),
 845                    new_text[prefix_len..].chars().rev(),
 846                );
 847                old_range.end = old_range.end.saturating_sub(suffix_len);
 848
 849                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 850                let range = if old_range.is_empty() {
 851                    let anchor = snapshot.anchor_after(old_range.start);
 852                    anchor..anchor
 853                } else {
 854                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 855                };
 856                (range, new_text)
 857            })
 858            .collect()
 859    }
 860
    /// Returns `true` if `completion_id` is present in the set of rated completions.
    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
        self.rated_completions.contains(&completion_id)
    }
 864
 865    pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
 866        self.shown_completions.push_front(completion.clone());
 867        if self.shown_completions.len() > 50 {
 868            let completion = self.shown_completions.pop_back().unwrap();
 869            self.rated_completions.remove(&completion.id);
 870        }
 871        cx.notify();
 872    }
 873
    /// Records a user rating for `completion`.
    ///
    /// Adds the completion's id to the rated set, emits an "Edit Prediction Rated"
    /// telemetry event carrying the rating, the completion's recorded inputs and
    /// output, and the free-form `feedback`, then flushes telemetry events and
    /// notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush right away rather than leaving the event queued.
        self.client.telemetry().flush_events();
        cx.notify();
    }
 894
    /// Iterates over the retained shown completions, front to back. New completions
    /// are pushed to the front, so the most recently shown one comes first.
    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
        self.shown_completions.iter()
    }
 898
    /// Returns the number of shown completions currently retained.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_completions.len()
    }
 902
 903    fn report_changes_for_buffer(
 904        &mut self,
 905        buffer: &Entity<Buffer>,
 906        cx: &mut Context<Self>,
 907    ) -> BufferSnapshot {
 908        self.register_buffer(buffer, cx);
 909
 910        let registered_buffer = self
 911            .registered_buffers
 912            .get_mut(&buffer.entity_id())
 913            .unwrap();
 914        let new_snapshot = buffer.read(cx).snapshot();
 915
 916        if new_snapshot.version != registered_buffer.snapshot.version {
 917            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 918            self.push_event(Event::BufferChange {
 919                old_snapshot,
 920                new_snapshot: new_snapshot.clone(),
 921                timestamp: Instant::now(),
 922            });
 923        }
 924
 925        new_snapshot
 926    }
 927
 928    fn load_data_collection_choices() -> DataCollectionChoice {
 929        let choice = KEY_VALUE_STORE
 930            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 931            .log_err()
 932            .flatten();
 933
 934        match choice.as_deref() {
 935            Some("true") => DataCollectionChoice::Enabled,
 936            Some("false") => DataCollectionChoice::Disabled,
 937            Some(_) => {
 938                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
 939                DataCollectionChoice::NotAnswered
 940            }
 941            None => DataCollectionChoice::NotAnswered,
 942        }
 943    }
 944}
 945
 946struct LicenseDetectionWatcher {
 947    is_open_source_rx: watch::Receiver<bool>,
 948    _is_open_source_task: Task<()>,
 949}
 950
 951impl LicenseDetectionWatcher {
 952    pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
 953        let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
 954
 955        let loaded_file_fut = worktree.load_file(Path::new("LICENSE"), cx);
 956
 957        Self {
 958            is_open_source_rx,
 959            _is_open_source_task: cx.spawn(|_, _| async move {
 960                // TODO: Don't display error if file not found
 961                let Some(loaded_file) = loaded_file_fut.await.log_err() else {
 962                    return;
 963                };
 964
 965                let is_loaded_file_open_source_thing: bool =
 966                    is_license_eligible_for_data_collection(&loaded_file.text);
 967
 968                *is_open_source_tx.borrow_mut() = is_loaded_file_open_source_thing;
 969            }),
 970        }
 971    }
 972
 973    /// Answers false until we find out it's open source
 974    pub fn is_open_source(&self) -> bool {
 975        *self.is_open_source_rx.borrow()
 976    }
 977}
 978
 979fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 980    a.zip(b)
 981        .take_while(|(a, b)| a == b)
 982        .map(|(a, _)| a.len_utf8())
 983        .sum()
 984}
 985
 986fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
 987    let mut input_outline = String::new();
 988
 989    writeln!(
 990        input_outline,
 991        "```{}",
 992        snapshot
 993            .file()
 994            .map_or(Cow::Borrowed("untitled"), |file| file
 995                .path()
 996                .to_string_lossy())
 997    )
 998    .unwrap();
 999
1000    if let Some(outline) = snapshot.outline(None) {
1001        let guess_size = outline.items.len() * 15;
1002        input_outline.reserve(guess_size);
1003        for item in outline.items.iter() {
1004            let spacing = " ".repeat(item.depth);
1005            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1006        }
1007    }
1008
1009    writeln!(input_outline, "```").unwrap();
1010
1011    input_outline
1012}
1013
1014fn prompt_for_excerpt(
1015    offset: usize,
1016    excerpt_range: &Range<usize>,
1017    mut len_guess: usize,
1018    path: &str,
1019    snapshot: &BufferSnapshot,
1020) -> String {
1021    let point_range = excerpt_range.to_point(snapshot);
1022
1023    // Include one line of extra context before and after editable range, if those lines are non-empty.
1024    let extra_context_before_range =
1025        if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
1026            let range =
1027                (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot);
1028            len_guess += range.end - range.start;
1029            Some(range)
1030        } else {
1031            None
1032        };
1033    let extra_context_after_range = if point_range.end.row < snapshot.max_point().row
1034        && !snapshot.is_line_blank(point_range.end.row + 1)
1035    {
1036        let range = (point_range.end
1037            ..Point::new(
1038                point_range.end.row + 1,
1039                snapshot.line_len(point_range.end.row + 1),
1040            ))
1041            .to_offset(snapshot);
1042        len_guess += range.end - range.start;
1043        Some(range)
1044    } else {
1045        None
1046    };
1047
1048    let mut prompt_excerpt = String::with_capacity(len_guess);
1049    writeln!(prompt_excerpt, "```{}", path).unwrap();
1050
1051    if excerpt_range.start == 0 {
1052        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
1053    }
1054
1055    if let Some(extra_context_before_range) = extra_context_before_range {
1056        for chunk in snapshot.text_for_range(extra_context_before_range) {
1057            prompt_excerpt.push_str(chunk);
1058        }
1059    }
1060    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
1061    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
1062        prompt_excerpt.push_str(chunk);
1063    }
1064    prompt_excerpt.push_str(CURSOR_MARKER);
1065    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
1066        prompt_excerpt.push_str(chunk);
1067    }
1068    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
1069
1070    if let Some(extra_context_after_range) = extra_context_after_range {
1071        for chunk in snapshot.text_for_range(extra_context_after_range) {
1072            prompt_excerpt.push_str(chunk);
1073        }
1074    }
1075
1076    write!(prompt_excerpt, "\n```").unwrap();
1077    debug_assert!(
1078        prompt_excerpt.len() <= len_guess,
1079        "Excerpt length {} exceeds estimated length {}",
1080        prompt_excerpt.len(),
1081        len_guess
1082    );
1083    prompt_excerpt
1084}
1085
1086fn excerpt_range_for_position(
1087    cursor_point: Point,
1088    byte_limit: usize,
1089    line_limit: u32,
1090    path: &str,
1091    snapshot: &BufferSnapshot,
1092) -> Result<(Range<usize>, usize)> {
1093    let cursor_row = cursor_point.row;
1094    let last_buffer_row = snapshot.max_point().row;
1095
1096    // This is an overestimate because it includes parts of prompt_for_excerpt which are
1097    // conditionally skipped.
1098    let mut len_guess = 0;
1099    len_guess += "```".len() + path.len() + 1;
1100    len_guess += START_OF_FILE_MARKER.len() + 1;
1101    len_guess += EDITABLE_REGION_START_MARKER.len() + 1;
1102    len_guess += CURSOR_MARKER.len();
1103    len_guess += EDITABLE_REGION_END_MARKER.len() + 1;
1104    len_guess += "```".len() + 1;
1105
1106    len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap();
1107
1108    if len_guess > byte_limit {
1109        return Err(anyhow!("Current line too long to send to model."));
1110    }
1111
1112    let mut excerpt_start_row = cursor_row;
1113    let mut excerpt_end_row = cursor_row;
1114    let mut no_more_before = cursor_row == 0;
1115    let mut no_more_after = cursor_row >= last_buffer_row;
1116    let mut row_delta = 1;
1117    loop {
1118        if !no_more_before {
1119            let row = cursor_point.row - row_delta;
1120            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
1121            let mut new_len_guess = len_guess + line_len;
1122            if row == 0 {
1123                new_len_guess += START_OF_FILE_MARKER.len() + 1;
1124            }
1125            if new_len_guess <= byte_limit {
1126                len_guess = new_len_guess;
1127                excerpt_start_row = row;
1128                if row == 0 {
1129                    no_more_before = true;
1130                }
1131            } else {
1132                no_more_before = true;
1133            }
1134        }
1135        if excerpt_end_row - excerpt_start_row >= line_limit {
1136            break;
1137        }
1138        if !no_more_after {
1139            let row = cursor_point.row + row_delta;
1140            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
1141            let new_len_guess = len_guess + line_len;
1142            if new_len_guess <= byte_limit {
1143                len_guess = new_len_guess;
1144                excerpt_end_row = row;
1145                if row >= last_buffer_row {
1146                    no_more_after = true;
1147                }
1148            } else {
1149                no_more_after = true;
1150            }
1151        }
1152        if excerpt_end_row - excerpt_start_row >= line_limit {
1153            break;
1154        }
1155        if no_more_before && no_more_after {
1156            break;
1157        }
1158        row_delta += 1;
1159    }
1160
1161    let excerpt_start = Point::new(excerpt_start_row, 0);
1162    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
1163    Ok((
1164        excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot),
1165        len_guess,
1166    ))
1167}
1168
1169fn prompt_for_events<'a>(
1170    events: impl Iterator<Item = &'a Event>,
1171    mut bytes_remaining: usize,
1172) -> String {
1173    let mut result = String::new();
1174    for event in events {
1175        if !result.is_empty() {
1176            result.push('\n');
1177            result.push('\n');
1178        }
1179        let event_string = event.to_prompt();
1180        let len = event_string.len();
1181        if len > PER_EVENT_BYTE_LIMIT {
1182            continue;
1183        }
1184        if len > bytes_remaining {
1185            break;
1186        }
1187        bytes_remaining -= len;
1188        result.push_str(&event_string);
1189    }
1190    result
1191}
1192
/// Bookkeeping kept for each buffer registered for change tracking.
struct RegisteredBuffer {
    /// Last snapshot recorded for the buffer; replaced in
    /// `report_changes_for_buffer` whenever the buffer's version has changed.
    snapshot: BufferSnapshot,
    /// Subscriptions held for as long as the buffer stays registered.
    // NOTE(review): presumably dropping these unsubscribes — confirm against gpui.
    _subscriptions: [gpui::Subscription; 2],
}
1197
1198#[derive(Clone)]
1199enum Event {
1200    BufferChange {
1201        old_snapshot: BufferSnapshot,
1202        new_snapshot: BufferSnapshot,
1203        timestamp: Instant,
1204    },
1205}
1206
1207impl Event {
1208    fn to_prompt(&self) -> String {
1209        match self {
1210            Event::BufferChange {
1211                old_snapshot,
1212                new_snapshot,
1213                ..
1214            } => {
1215                let mut prompt = String::new();
1216
1217                let old_path = old_snapshot
1218                    .file()
1219                    .map(|f| f.path().as_ref())
1220                    .unwrap_or(Path::new("untitled"));
1221                let new_path = new_snapshot
1222                    .file()
1223                    .map(|f| f.path().as_ref())
1224                    .unwrap_or(Path::new("untitled"));
1225                if old_path != new_path {
1226                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1227                }
1228
1229                let diff =
1230                    similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
1231                        .unified_diff()
1232                        .to_string();
1233                if !diff.is_empty() {
1234                    write!(
1235                        prompt,
1236                        "User edited {:?}:\n```diff\n{}\n```",
1237                        new_path, diff
1238                    )
1239                    .unwrap();
1240                }
1241
1242                prompt
1243            }
1244        }
1245    }
1246}
1247
1248#[derive(Debug, Clone)]
1249struct CurrentInlineCompletion {
1250    buffer_id: EntityId,
1251    completion: InlineCompletion,
1252}
1253
1254impl CurrentInlineCompletion {
1255    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1256        if self.buffer_id != old_completion.buffer_id {
1257            return true;
1258        }
1259
1260        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1261            return true;
1262        };
1263        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1264            return false;
1265        };
1266
1267        if old_edits.len() == 1 && new_edits.len() == 1 {
1268            let (old_range, old_text) = &old_edits[0];
1269            let (new_range, new_text) = &new_edits[0];
1270            new_range == old_range && new_text.starts_with(old_text)
1271        } else {
1272            true
1273        }
1274    }
1275}
1276
/// A completion request that has been started but has not finished yet.
struct PendingCompletion {
    /// Identifier taken from the provider's `next_pending_completion_id` counter
    /// when the request was started.
    id: usize,
    /// The spawned task carrying out the request.
    // NOTE(review): presumably dropping the task cancels the request — confirm against gpui.
    _task: Task<()>,
}
1281
/// The user's answer to whether prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has chosen either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1321
1322pub struct ProviderDataCollection {
1323    /// When set to None, data collection is not possible in the provider buffer
1324    choice: Option<Entity<DataCollectionChoice>>,
1325    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1326}
1327
1328impl ProviderDataCollection {
1329    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1330        let choice_and_watcher = buffer.and_then(|buffer| {
1331            let file = buffer.read(cx).file()?;
1332
1333            if !file.is_local() || file.is_private() {
1334                return None;
1335            }
1336
1337            let zeta = zeta.read(cx);
1338            let choice = zeta.data_collection_choice.clone();
1339
1340            // Unwrap safety: there should be a watcher for each worktree
1341            let license_detection_watcher = zeta
1342                .license_detection_watchers
1343                .get(&file.worktree_id(cx))
1344                .cloned()?;
1345
1346            Some((choice, license_detection_watcher))
1347        });
1348
1349        if let Some((choice, watcher)) = choice_and_watcher {
1350            ProviderDataCollection {
1351                choice: Some(choice),
1352                license_detection_watcher: Some(watcher),
1353            }
1354        } else {
1355            ProviderDataCollection {
1356                choice: None,
1357                license_detection_watcher: None,
1358            }
1359        }
1360    }
1361
1362    pub fn user_data_collection_choice(&self, cx: &App) -> bool {
1363        self.choice
1364            .as_ref()
1365            .map_or(false, |choice| choice.read(cx).is_enabled())
1366    }
1367
1368    pub fn data_collection_permission(&self, cx: &App) -> bool {
1369        self.choice
1370            .as_ref()
1371            .is_some_and(|choice| choice.read(cx).is_enabled())
1372            && self
1373                .license_detection_watcher
1374                .as_ref()
1375                .is_some_and(|watcher| watcher.is_open_source())
1376    }
1377
1378    pub fn toggle(&mut self, cx: &mut App) {
1379        if let Some(choice) = self.choice.as_mut() {
1380            let new_choice = choice.update(cx, |choice, _cx| {
1381                let new_choice = choice.toggle();
1382                *choice = new_choice;
1383                new_choice
1384            });
1385
1386            db::write_and_log(cx, move || {
1387                KEY_VALUE_STORE.write_kvp(
1388                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1389                    new_choice.is_enabled().to_string(),
1390                )
1391            });
1392        }
1393    }
1394}
1395
/// Inline completion provider backed by a shared `Zeta` entity.
pub struct ZetaInlineCompletionProvider {
    zeta: Entity<Zeta>,
    /// In-flight completion requests; at most two are kept at a time.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next pending completion; incremented on every request.
    next_pending_completion_id: usize,
    /// The completion currently offered, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection state for this provider. It is always present; when data
    /// collection is not possible for the buffer, its inner fields are `None`.
    provider_data_collection: ProviderDataCollection,
    /// When the last completion request was issued; used for throttling.
    last_request_timestamp: Instant,
}

impl ZetaInlineCompletionProvider {
    /// Minimum delay enforced between two consecutive completion requests.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending or current completion.
    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
        Self {
            zeta,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            provider_data_collection,
            last_request_timestamp: Instant::now(),
        }
    }
}
1420
1421impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
1422    fn name() -> &'static str {
1423        "zed-predict"
1424    }
1425
1426    fn display_name() -> &'static str {
1427        "Zed's Edit Predictions"
1428    }
1429
1430    fn show_completions_in_menu() -> bool {
1431        true
1432    }
1433
1434    fn show_completions_in_normal_mode() -> bool {
1435        true
1436    }
1437
1438    fn show_tab_accept_marker() -> bool {
1439        true
1440    }
1441
1442    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1443        if self
1444            .provider_data_collection
1445            .user_data_collection_choice(cx)
1446        {
1447            DataCollectionState::Enabled
1448        } else {
1449            DataCollectionState::Disabled
1450        }
1451    }
1452
1453    fn toggle_data_collection(&mut self, cx: &mut App) {
1454        self.provider_data_collection.toggle(cx);
1455    }
1456
1457    fn is_enabled(
1458        &self,
1459        buffer: &Entity<Buffer>,
1460        cursor_position: language::Anchor,
1461        cx: &App,
1462    ) -> bool {
1463        let buffer = buffer.read(cx);
1464        let file = buffer.file();
1465        let language = buffer.language_at(cursor_position);
1466        let settings = all_language_settings(file, cx);
1467        settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1468    }
1469
1470    fn needs_terms_acceptance(&self, cx: &App) -> bool {
1471        !self.zeta.read(cx).tos_accepted
1472    }
1473
1474    fn is_refreshing(&self) -> bool {
1475        !self.pending_completions.is_empty()
1476    }
1477
1478    fn refresh(
1479        &mut self,
1480        buffer: Entity<Buffer>,
1481        position: language::Anchor,
1482        _debounce: bool,
1483        cx: &mut Context<Self>,
1484    ) {
1485        if !self.zeta.read(cx).tos_accepted {
1486            return;
1487        }
1488
1489        if let Some(current_completion) = self.current_completion.as_ref() {
1490            let snapshot = buffer.read(cx).snapshot();
1491            if current_completion
1492                .completion
1493                .interpolate(&snapshot)
1494                .is_some()
1495            {
1496                return;
1497            }
1498        }
1499
1500        let pending_completion_id = self.next_pending_completion_id;
1501        self.next_pending_completion_id += 1;
1502        let data_collection_permission =
1503            self.provider_data_collection.data_collection_permission(cx);
1504        let last_request_timestamp = self.last_request_timestamp;
1505
1506        let task = cx.spawn(|this, mut cx| async move {
1507            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1508                .checked_duration_since(Instant::now())
1509            {
1510                cx.background_executor().timer(timeout).await;
1511            }
1512
1513            let completion_request = this.update(&mut cx, |this, cx| {
1514                this.last_request_timestamp = Instant::now();
1515                this.zeta.update(cx, |zeta, cx| {
1516                    zeta.request_completion(&buffer, position, data_collection_permission, cx)
1517                })
1518            });
1519
1520            let completion = match completion_request {
1521                Ok(completion_request) => {
1522                    let completion_request = completion_request.await;
1523                    completion_request.map(|c| {
1524                        c.map(|completion| CurrentInlineCompletion {
1525                            buffer_id: buffer.entity_id(),
1526                            completion,
1527                        })
1528                    })
1529                }
1530                Err(error) => Err(error),
1531            };
1532            let Some(new_completion) = completion
1533                .context("edit prediction failed")
1534                .log_err()
1535                .flatten()
1536            else {
1537                this.update(&mut cx, |this, cx| {
1538                    if this.pending_completions[0].id == pending_completion_id {
1539                        this.pending_completions.remove(0);
1540                    } else {
1541                        this.pending_completions.clear();
1542                    }
1543
1544                    cx.notify();
1545                })
1546                .ok();
1547                return;
1548            };
1549
1550            this.update(&mut cx, |this, cx| {
1551                if this.pending_completions[0].id == pending_completion_id {
1552                    this.pending_completions.remove(0);
1553                } else {
1554                    this.pending_completions.clear();
1555                }
1556
1557                if let Some(old_completion) = this.current_completion.as_ref() {
1558                    let snapshot = buffer.read(cx).snapshot();
1559                    if new_completion.should_replace_completion(&old_completion, &snapshot) {
1560                        this.zeta.update(cx, |zeta, cx| {
1561                            zeta.completion_shown(&new_completion.completion, cx);
1562                        });
1563                        this.current_completion = Some(new_completion);
1564                    }
1565                } else {
1566                    this.zeta.update(cx, |zeta, cx| {
1567                        zeta.completion_shown(&new_completion.completion, cx);
1568                    });
1569                    this.current_completion = Some(new_completion);
1570                }
1571
1572                cx.notify();
1573            })
1574            .ok();
1575        });
1576
1577        // We always maintain at most two pending completions. When we already
1578        // have two, we replace the newest one.
1579        if self.pending_completions.len() <= 1 {
1580            self.pending_completions.push(PendingCompletion {
1581                id: pending_completion_id,
1582                _task: task,
1583            });
1584        } else if self.pending_completions.len() == 2 {
1585            self.pending_completions.pop();
1586            self.pending_completions.push(PendingCompletion {
1587                id: pending_completion_id,
1588                _task: task,
1589            });
1590        }
1591    }
1592
1593    fn cycle(
1594        &mut self,
1595        _buffer: Entity<Buffer>,
1596        _cursor_position: language::Anchor,
1597        _direction: inline_completion::Direction,
1598        _cx: &mut Context<Self>,
1599    ) {
1600        // Right now we don't support cycling.
1601    }
1602
1603    fn accept(&mut self, _cx: &mut Context<Self>) {
1604        self.pending_completions.clear();
1605    }
1606
1607    fn discard(&mut self, _cx: &mut Context<Self>) {
1608        self.pending_completions.clear();
1609        self.current_completion.take();
1610    }
1611
1612    fn suggest(
1613        &mut self,
1614        buffer: &Entity<Buffer>,
1615        cursor_position: language::Anchor,
1616        cx: &mut Context<Self>,
1617    ) -> Option<inline_completion::InlineCompletion> {
1618        let CurrentInlineCompletion {
1619            buffer_id,
1620            completion,
1621            ..
1622        } = self.current_completion.as_mut()?;
1623
1624        // Invalidate previous completion if it was generated for a different buffer.
1625        if *buffer_id != buffer.entity_id() {
1626            self.current_completion.take();
1627            return None;
1628        }
1629
1630        let buffer = buffer.read(cx);
1631        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1632            self.current_completion.take();
1633            return None;
1634        };
1635
1636        let cursor_row = cursor_position.to_point(buffer).row;
1637        let (closest_edit_ix, (closest_edit_range, _)) =
1638            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1639                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1640                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1641                cmp::min(distance_from_start, distance_from_end)
1642            })?;
1643
1644        let mut edit_start_ix = closest_edit_ix;
1645        for (range, _) in edits[..edit_start_ix].iter().rev() {
1646            let distance_from_closest_edit =
1647                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1648            if distance_from_closest_edit <= 1 {
1649                edit_start_ix -= 1;
1650            } else {
1651                break;
1652            }
1653        }
1654
1655        let mut edit_end_ix = closest_edit_ix + 1;
1656        for (range, _) in &edits[edit_end_ix..] {
1657            let distance_from_closest_edit =
1658                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1659            if distance_from_closest_edit <= 1 {
1660                edit_end_ix += 1;
1661            } else {
1662                break;
1663            }
1664        }
1665
1666        Some(inline_completion::InlineCompletion {
1667            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1668            edit_preview: Some(completion.edit_preview.clone()),
1669        })
1670    }
1671}
1672
1673#[cfg(test)]
1674mod tests {
1675    use client::test::FakeServer;
1676    use clock::FakeSystemClock;
1677    use gpui::TestAppContext;
1678    use http_client::FakeHttpClient;
1679    use indoc::indoc;
1680    use language_models::RefreshLlmTokenListener;
1681    use rpc::proto;
1682    use settings::SettingsStore;
1683
1684    use super::*;
1685
1686    #[gpui::test]
1687    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
1688        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1689        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1690            to_completion_edits(
1691                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1692                &buffer,
1693                cx,
1694            )
1695            .into()
1696        });
1697
1698        let edit_preview = cx
1699            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1700            .await;
1701
1702        let completion = InlineCompletion {
1703            edits,
1704            edit_preview,
1705            path: Path::new("").into(),
1706            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1707            id: InlineCompletionId::new(),
1708            excerpt_range: 0..0,
1709            cursor_offset: 0,
1710            input_outline: "".into(),
1711            input_events: "".into(),
1712            input_excerpt: "".into(),
1713            output_excerpt: "".into(),
1714            request_sent_at: Instant::now(),
1715            response_received_at: Instant::now(),
1716        };
1717
1718        cx.update(|cx| {
1719            assert_eq!(
1720                from_completion_edits(
1721                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1722                    &buffer,
1723                    cx
1724                ),
1725                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1726            );
1727
1728            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1729            assert_eq!(
1730                from_completion_edits(
1731                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1732                    &buffer,
1733                    cx
1734                ),
1735                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1736            );
1737
1738            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1739            assert_eq!(
1740                from_completion_edits(
1741                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1742                    &buffer,
1743                    cx
1744                ),
1745                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1746            );
1747
1748            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1749            assert_eq!(
1750                from_completion_edits(
1751                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1752                    &buffer,
1753                    cx
1754                ),
1755                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1756            );
1757
1758            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1759            assert_eq!(
1760                from_completion_edits(
1761                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1762                    &buffer,
1763                    cx
1764                ),
1765                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1766            );
1767
1768            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1769            assert_eq!(
1770                from_completion_edits(
1771                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1772                    &buffer,
1773                    cx
1774                ),
1775                vec![(9..11, "".to_string())]
1776            );
1777
1778            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1779            assert_eq!(
1780                from_completion_edits(
1781                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1782                    &buffer,
1783                    cx
1784                ),
1785                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1786            );
1787
1788            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1789            assert_eq!(
1790                from_completion_edits(
1791                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1792                    &buffer,
1793                    cx
1794                ),
1795                vec![(4..4, "M".to_string())]
1796            );
1797
1798            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1799            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1800        })
1801    }
1802
1803    #[gpui::test]
1804    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
1805        cx.update(|cx| {
1806            let settings_store = SettingsStore::test(cx);
1807            cx.set_global(settings_store);
1808            client::init_settings(cx);
1809        });
1810
1811        let buffer_content = "lorem\n";
1812        let completion_response = indoc! {"
1813            ```animals.js
1814            <|start_of_file|>
1815            <|editable_region_start|>
1816            lorem
1817            ipsum
1818            <|editable_region_end|>
1819            ```"};
1820
1821        let http_client = FakeHttpClient::create(move |_| async move {
1822            Ok(http_client::Response::builder()
1823                .status(200)
1824                .body(
1825                    serde_json::to_string(&PredictEditsResponse {
1826                        output_excerpt: completion_response.to_string(),
1827                    })
1828                    .unwrap()
1829                    .into(),
1830                )
1831                .unwrap())
1832        });
1833
1834        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1835        cx.update(|cx| {
1836            RefreshLlmTokenListener::register(client.clone(), cx);
1837        });
1838        let server = FakeServer::for_client(42, &client, cx).await;
1839        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1840        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
1841
1842        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1843        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1844        let completion_task = zeta.update(cx, |zeta, cx| {
1845            zeta.request_completion(&buffer, cursor, false, cx)
1846        });
1847
1848        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
1849        server.respond(
1850            token_request.receipt(),
1851            proto::GetLlmTokenResponse { token: "".into() },
1852        );
1853
1854        let completion = completion_task.await.unwrap().unwrap();
1855        buffer.update(cx, |buffer, cx| {
1856            buffer.edit(completion.edits.iter().cloned(), None, cx)
1857        });
1858        assert_eq!(
1859            buffer.read_with(cx, |buffer, _| buffer.text()),
1860            "lorem\nipsum"
1861        );
1862    }
1863
1864    fn to_completion_edits(
1865        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1866        buffer: &Entity<Buffer>,
1867        cx: &App,
1868    ) -> Vec<(Range<Anchor>, String)> {
1869        let buffer = buffer.read(cx);
1870        iterator
1871            .into_iter()
1872            .map(|(range, text)| {
1873                (
1874                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1875                    text,
1876                )
1877            })
1878            .collect()
1879    }
1880
1881    fn from_completion_edits(
1882        editor_edits: &[(Range<Anchor>, String)],
1883        buffer: &Entity<Buffer>,
1884        cx: &App,
1885    ) -> Vec<(Range<usize>, String)> {
1886        let buffer = buffer.read(cx);
1887        editor_edits
1888            .iter()
1889            .map(|(range, text)| {
1890                (
1891                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
1892                    text.clone(),
1893                )
1894            })
1895            .collect()
1896    }
1897
1898    #[ctor::ctor]
1899    fn init_logger() {
1900        if std::env::var("RUST_LOG").is_ok() {
1901            env_logger::init();
1902        }
1903    }
1904}