zeta.rs

   1mod completion_diff_element;
   2mod persistence;
   3mod rate_completion_modal;
   4
   5pub(crate) use completion_diff_element::*;
   6use db::kvp::KEY_VALUE_STORE;
   7use inline_completion::DataCollectionState;
   8pub use rate_completion_modal::*;
   9
  10use anyhow::{anyhow, Context as _, Result};
  11use arrayvec::ArrayVec;
  12use client::{Client, UserStore};
  13use collections::{HashMap, HashSet, VecDeque};
  14use futures::AsyncReadExt;
  15use gpui::{
  16    actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
  17    WeakEntity,
  18};
  19use http_client::{HttpClient, Method};
  20use language::{
  21    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
  22    Point, ToOffset, ToPoint,
  23};
  24use language_models::LlmApiToken;
  25use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
  26use serde::{Deserialize, Serialize};
  27use std::{
  28    borrow::Cow,
  29    cmp, env,
  30    fmt::Write,
  31    future::Future,
  32    mem,
  33    ops::Range,
  34    path::{Path, PathBuf},
  35    sync::Arc,
  36    time::{Duration, Instant},
  37};
  38use telemetry_events::InlineCompletionRating;
  39use util::ResultExt;
  40use uuid::Uuid;
  41use workspace::{
  42    notifications::{simple_message_notification::MessageNotification, NotificationId},
  43    Workspace,
  44};
  45
  46const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  47const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  48const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  49const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  50const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  51const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &'static str =
  52    "zed_predict_data_collection_never_ask_again";
  53
// Declares the `ClearHistory` action under the `edit_prediction` namespace.
actions!(edit_prediction, [ClearHistory]);
  55
  56#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  57pub struct InlineCompletionId(Uuid);
  58
  59impl From<InlineCompletionId> for gpui::ElementId {
  60    fn from(value: InlineCompletionId) -> Self {
  61        gpui::ElementId::Uuid(value.0)
  62    }
  63}
  64
  65impl std::fmt::Display for InlineCompletionId {
  66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  67        write!(f, "{}", self.0)
  68    }
  69}
  70
  71impl InlineCompletionId {
  72    fn new() -> Self {
  73        Self(Uuid::new_v4())
  74    }
  75}
  76
/// App-wide global holding the `Zeta` entity; installed by `Zeta::register` and read back
/// by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  81
/// A single edit prediction, together with the prompt it was produced from.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Randomly generated id (see `InlineCompletionId::new`).
    id: InlineCompletionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range in `snapshot` of the excerpt sent to the model.
    excerpt_range: Range<usize>,
    /// Cursor offset in `snapshot` at request time.
    cursor_offset: usize,
    /// Predicted edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    /// Outline prompt sent with the request.
    input_outline: Arc<str>,
    /// Recent-edit-events prompt sent with the request.
    input_events: Arc<str>,
    /// Excerpt prompt sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model (including markers).
    output_excerpt: Arc<str>,
    /// When the request task started.
    request_sent_at: Instant,
    /// When the response finished being processed into this completion.
    response_received_at: Instant,
}
  97
impl InlineCompletion {
    /// Elapsed time between `request_sent_at` and `response_received_at`.
    fn latency(&self) -> Duration {
        self.response_received_at
            .duration_since(self.request_sent_at)
    }

    /// Re-bases this completion's predicted edits onto `new_snapshot`, accounting for the
    /// edits the user made since `self.snapshot` was captured.
    ///
    /// Returns `None` when a user edit overlaps a predicted edit in a way that cannot be
    /// reconciled (the prediction is stale), or when no predicted edits remain. Otherwise
    /// returns the edits that still apply.
    ///
    /// NOTE(review): the walk peeks user edits in lockstep with `self.edits`, so it
    /// presumably relies on `self.edits` being sorted by position — confirm.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
        let mut edits = Vec::new();

        // User edits since the snapshot the prediction was made against, as offset ranges.
        let mut user_edits = new_snapshot
            .edits_since::<usize>(&self.snapshot.version)
            .peekable();
        for (model_old_range, model_new_text) in self.edits.iter() {
            let model_offset_range = model_old_range.to_offset(&self.snapshot);
            // Skip user edits that end before this predicted edit begins.
            while let Some(next_user_edit) = user_edits.peek() {
                if next_user_edit.old.end < model_offset_range.start {
                    user_edits.next();
                } else {
                    break;
                }
            }

            if let Some(user_edit) = user_edits.peek() {
                if user_edit.old.start > model_offset_range.end {
                    // The next user edit lies entirely after this prediction: keep it as-is.
                    edits.push((model_old_range.clone(), model_new_text.clone()));
                } else if user_edit.old == model_offset_range {
                    // The user replaced exactly the predicted range. If what they typed is a
                    // prefix of the predicted text, only the remaining suffix is still needed.
                    let user_new_text = new_snapshot
                        .text_for_range(user_edit.new.clone())
                        .collect::<String>();

                    if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                        if !model_suffix.is_empty() {
                            // Empty range at the end of the user's inserted text.
                            edits.push((
                                new_snapshot.anchor_after(user_edit.new.end)
                                    ..new_snapshot.anchor_before(user_edit.new.end),
                                model_suffix.into(),
                            ));
                        }

                        user_edits.next();
                    } else {
                        // The user typed something that diverges from the prediction.
                        return None;
                    }
                } else {
                    // Partial overlap between a user edit and the predicted range.
                    return None;
                }
            } else {
                // No user edits left: the remaining predictions apply unchanged.
                edits.push((model_old_range.clone(), model_new_text.clone()));
            }
        }

        if edits.is_empty() {
            None
        } else {
            Some(edits)
        }
    }
}
 156
 157impl std::fmt::Debug for InlineCompletion {
 158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 159        f.debug_struct("InlineCompletion")
 160            .field("id", &self.id)
 161            .field("path", &self.path)
 162            .field("edits", &self.edits)
 163            .finish_non_exhaustive()
 164    }
 165}
 166
/// Edit-prediction state: records buffer edits as prompt events, requests predictions from
/// the `/predict_edits` endpoint, and keeps recently shown completions for rating.
pub struct Zeta {
    client: Arc<Client>,
    /// Recent buffer-change events, oldest first; bounded and coalesced by `push_event`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Completions that have been shown, most recent first.
    shown_completions: VecDeque<InlineCompletion>,
    /// Ids of completions the user has rated.
    rated_completions: HashSet<InlineCompletionId>,
    data_collection_preferences: DataCollectionPreferences,
    /// Token for `/predict_edits` requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    tos_accepted: bool, // Terms of service accepted
    /// Keeps `tos_accepted` in sync with the user store.
    _user_store_subscription: Subscription,
}
 179
 180impl Zeta {
 181    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 182        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 183    }
 184
 185    pub fn register(
 186        client: Arc<Client>,
 187        user_store: Entity<UserStore>,
 188        cx: &mut App,
 189    ) -> Entity<Self> {
 190        Self::global(cx).unwrap_or_else(|| {
 191            let model = cx.new(|cx| Self::new(client, user_store, cx));
 192            cx.set_global(ZetaGlobal(model.clone()));
 193            model
 194        })
 195    }
 196
 197    pub fn clear_history(&mut self) {
 198        self.events.clear();
 199    }
 200
 201    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 202        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
 203        Self {
 204            client,
 205            events: VecDeque::new(),
 206            shown_completions: VecDeque::new(),
 207            rated_completions: HashSet::default(),
 208            registered_buffers: HashMap::default(),
 209            data_collection_preferences: Self::load_data_collection_preferences(cx),
 210            llm_token: LlmApiToken::default(),
 211            _llm_token_subscription: cx.subscribe(
 212                &refresh_llm_token_listener,
 213                |this, _listener, _event, cx| {
 214                    let client = this.client.clone();
 215                    let llm_token = this.llm_token.clone();
 216                    cx.spawn(|_this, _cx| async move {
 217                        llm_token.refresh(&client).await?;
 218                        anyhow::Ok(())
 219                    })
 220                    .detach_and_log_err(cx);
 221                },
 222            ),
 223            tos_accepted: user_store
 224                .read(cx)
 225                .current_user_has_accepted_terms()
 226                .unwrap_or(false),
 227            _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
 228                match event {
 229                    client::user::Event::PrivateUserInfoUpdated => {
 230                        this.tos_accepted = user_store
 231                            .read(cx)
 232                            .current_user_has_accepted_terms()
 233                            .unwrap_or(false);
 234                    }
 235                    _ => {}
 236                }
 237            }),
 238        }
 239    }
 240
 241    fn push_event(&mut self, event: Event) {
 242        const MAX_EVENT_COUNT: usize = 16;
 243
 244        if let Some(Event::BufferChange {
 245            new_snapshot: last_new_snapshot,
 246            timestamp: last_timestamp,
 247            ..
 248        }) = self.events.back_mut()
 249        {
 250            // Coalesce edits for the same buffer when they happen one after the other.
 251            let Event::BufferChange {
 252                old_snapshot,
 253                new_snapshot,
 254                timestamp,
 255            } = &event;
 256
 257            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 258                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 259                && old_snapshot.version == last_new_snapshot.version
 260            {
 261                *last_new_snapshot = new_snapshot.clone();
 262                *last_timestamp = *timestamp;
 263                return;
 264            }
 265        }
 266
 267        self.events.push_back(event);
 268        if self.events.len() >= MAX_EVENT_COUNT {
 269            self.events.drain(..MAX_EVENT_COUNT / 2);
 270        }
 271    }
 272
 273    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 274        let buffer_id = buffer.entity_id();
 275        let weak_buffer = buffer.downgrade();
 276
 277        if let std::collections::hash_map::Entry::Vacant(entry) =
 278            self.registered_buffers.entry(buffer_id)
 279        {
 280            let snapshot = buffer.read(cx).snapshot();
 281
 282            entry.insert(RegisteredBuffer {
 283                snapshot,
 284                _subscriptions: [
 285                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 286                        this.handle_buffer_event(buffer, event, cx);
 287                    }),
 288                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 289                        this.registered_buffers.remove(&weak_buffer.entity_id());
 290                    }),
 291                ],
 292            });
 293        };
 294    }
 295
 296    fn handle_buffer_event(
 297        &mut self,
 298        buffer: Entity<Buffer>,
 299        event: &language::BufferEvent,
 300        cx: &mut Context<Self>,
 301    ) {
 302        if let language::BufferEvent::Edited = event {
 303            self.report_changes_for_buffer(&buffer, cx);
 304        }
 305    }
 306
 307    pub fn request_completion_impl<F, R>(
 308        &mut self,
 309        buffer: &Entity<Buffer>,
 310        cursor: language::Anchor,
 311        can_collect_data: bool,
 312        cx: &mut Context<Self>,
 313        perform_predict_edits: F,
 314    ) -> Task<Result<Option<InlineCompletion>>>
 315    where
 316        F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
 317        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 318    {
 319        let snapshot = self.report_changes_for_buffer(buffer, cx);
 320        let point = cursor.to_point(&snapshot);
 321        let offset = point.to_offset(&snapshot);
 322        let excerpt_range = excerpt_range_for_position(point, &snapshot);
 323        let events = self.events.clone();
 324        let path = snapshot
 325            .file()
 326            .map(|f| Arc::from(f.full_path(cx).as_path()))
 327            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 328
 329        let client = self.client.clone();
 330        let llm_token = self.llm_token.clone();
 331
 332        cx.spawn(|_, cx| async move {
 333            let request_sent_at = Instant::now();
 334
 335            let (input_events, input_excerpt, input_outline) = cx
 336                .background_executor()
 337                .spawn({
 338                    let snapshot = snapshot.clone();
 339                    let excerpt_range = excerpt_range.clone();
 340                    async move {
 341                        let mut input_events = String::new();
 342                        for event in events {
 343                            if !input_events.is_empty() {
 344                                input_events.push('\n');
 345                                input_events.push('\n');
 346                            }
 347                            input_events.push_str(&event.to_prompt());
 348                        }
 349
 350                        let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
 351                        let input_outline = prompt_for_outline(&snapshot);
 352
 353                        (input_events, input_excerpt, input_outline)
 354                    }
 355                })
 356                .await;
 357
 358            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 359
 360            let body = PredictEditsParams {
 361                input_events: input_events.clone(),
 362                input_excerpt: input_excerpt.clone(),
 363                outline: Some(input_outline.clone()),
 364                can_collect_data,
 365            };
 366
 367            let response = perform_predict_edits(client, llm_token, body).await?;
 368
 369            let output_excerpt = response.output_excerpt;
 370            log::debug!("completion response: {}", output_excerpt);
 371
 372            Self::process_completion_response(
 373                output_excerpt,
 374                &snapshot,
 375                excerpt_range,
 376                offset,
 377                path,
 378                input_outline,
 379                input_events,
 380                input_excerpt,
 381                request_sent_at,
 382                &cx,
 383            )
 384            .await
 385        })
 386    }
 387
 388    // Generates several example completions of various states to fill the Zeta completion modal
 389    #[cfg(any(test, feature = "test-support"))]
 390    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
 391        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 392            And maybe a short line
 393
 394            Then a few lines
 395
 396            and then another
 397            "#};
 398
 399        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
 400        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 401
 402        let completion_tasks = vec![
 403            self.fake_completion(
 404                &buffer,
 405                position,
 406                PredictEditsResponse {
 407                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 408a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 409[here's an edit]
 410And maybe a short line
 411Then a few lines
 412and then another
 413{EDITABLE_REGION_END_MARKER}
 414                        ", ),
 415                },
 416                cx,
 417            ),
 418            self.fake_completion(
 419                &buffer,
 420                position,
 421                PredictEditsResponse {
 422                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 423a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 424And maybe a short line
 425[and another edit]
 426Then a few lines
 427and then another
 428{EDITABLE_REGION_END_MARKER}
 429                        "#),
 430                },
 431                cx,
 432            ),
 433            self.fake_completion(
 434                &buffer,
 435                position,
 436                PredictEditsResponse {
 437                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 438a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 439And maybe a short line
 440
 441Then a few lines
 442
 443and then another
 444{EDITABLE_REGION_END_MARKER}
 445                        "#),
 446                },
 447                cx,
 448            ),
 449            self.fake_completion(
 450                &buffer,
 451                position,
 452                PredictEditsResponse {
 453                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 454a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 455And maybe a short line
 456
 457Then a few lines
 458
 459and then another
 460{EDITABLE_REGION_END_MARKER}
 461                        "#),
 462                },
 463                cx,
 464            ),
 465            self.fake_completion(
 466                &buffer,
 467                position,
 468                PredictEditsResponse {
 469                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 470a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 471And maybe a short line
 472Then a few lines
 473[a third completion]
 474and then another
 475{EDITABLE_REGION_END_MARKER}
 476                        "#),
 477                },
 478                cx,
 479            ),
 480            self.fake_completion(
 481                &buffer,
 482                position,
 483                PredictEditsResponse {
 484                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 485a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 486And maybe a short line
 487and then another
 488[fourth completion example]
 489{EDITABLE_REGION_END_MARKER}
 490                        "#),
 491                },
 492                cx,
 493            ),
 494            self.fake_completion(
 495                &buffer,
 496                position,
 497                PredictEditsResponse {
 498                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 499a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 500And maybe a short line
 501Then a few lines
 502and then another
 503[fifth and final completion]
 504{EDITABLE_REGION_END_MARKER}
 505                        "#),
 506                },
 507                cx,
 508            ),
 509        ];
 510
 511        cx.spawn(|zeta, mut cx| async move {
 512            for task in completion_tasks {
 513                task.await.unwrap();
 514            }
 515
 516            zeta.update(&mut cx, |zeta, _cx| {
 517                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
 518                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
 519            })
 520            .ok();
 521        })
 522    }
 523
 524    #[cfg(any(test, feature = "test-support"))]
 525    pub fn fake_completion(
 526        &mut self,
 527        buffer: &Entity<Buffer>,
 528        position: language::Anchor,
 529        response: PredictEditsResponse,
 530        cx: &mut Context<Self>,
 531    ) -> Task<Result<Option<InlineCompletion>>> {
 532        use std::future::ready;
 533
 534        self.request_completion_impl(buffer, position, false, cx, |_, _, _| ready(Ok(response)))
 535    }
 536
 537    pub fn request_completion(
 538        &mut self,
 539        buffer: &Entity<Buffer>,
 540        position: language::Anchor,
 541        can_collect_data: bool,
 542        cx: &mut Context<Self>,
 543    ) -> Task<Result<Option<InlineCompletion>>> {
 544        self.request_completion_impl(
 545            buffer,
 546            position,
 547            can_collect_data,
 548            cx,
 549            Self::perform_predict_edits,
 550        )
 551    }
 552
 553    fn perform_predict_edits(
 554        client: Arc<Client>,
 555        llm_token: LlmApiToken,
 556        body: PredictEditsParams,
 557    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 558        async move {
 559            let http_client = client.http_client();
 560            let mut token = llm_token.acquire(&client).await?;
 561            let mut did_retry = false;
 562
 563            loop {
 564                let request_builder = http_client::Request::builder();
 565                let request = request_builder
 566                    .method(Method::POST)
 567                    .uri(
 568                        http_client
 569                            .build_zed_llm_url("/predict_edits", &[])?
 570                            .as_ref(),
 571                    )
 572                    .header("Content-Type", "application/json")
 573                    .header("Authorization", format!("Bearer {}", token))
 574                    .body(serde_json::to_string(&body)?.into())?;
 575
 576                let mut response = http_client.send(request).await?;
 577
 578                if response.status().is_success() {
 579                    let mut body = String::new();
 580                    response.body_mut().read_to_string(&mut body).await?;
 581                    return Ok(serde_json::from_str(&body)?);
 582                } else if !did_retry
 583                    && response
 584                        .headers()
 585                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 586                        .is_some()
 587                {
 588                    did_retry = true;
 589                    token = llm_token.refresh(&client).await?;
 590                } else {
 591                    let mut body = String::new();
 592                    response.body_mut().read_to_string(&mut body).await?;
 593                    return Err(anyhow!(
 594                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 595                        response.status(),
 596                        body
 597                    ));
 598                }
 599            }
 600        }
 601    }
 602
 603    #[allow(clippy::too_many_arguments)]
 604    fn process_completion_response(
 605        output_excerpt: String,
 606        snapshot: &BufferSnapshot,
 607        excerpt_range: Range<usize>,
 608        cursor_offset: usize,
 609        path: Arc<Path>,
 610        input_outline: String,
 611        input_events: String,
 612        input_excerpt: String,
 613        request_sent_at: Instant,
 614        cx: &AsyncApp,
 615    ) -> Task<Result<Option<InlineCompletion>>> {
 616        let snapshot = snapshot.clone();
 617        cx.background_executor().spawn(async move {
 618            let content = output_excerpt.replace(CURSOR_MARKER, "");
 619
 620            let start_markers = content
 621                .match_indices(EDITABLE_REGION_START_MARKER)
 622                .collect::<Vec<_>>();
 623            anyhow::ensure!(
 624                start_markers.len() == 1,
 625                "expected exactly one start marker, found {}",
 626                start_markers.len()
 627            );
 628
 629            let codefence_start = start_markers[0].0;
 630            let content = &content[codefence_start..];
 631
 632            let newline_ix = content.find('\n').context("could not find newline")?;
 633            let content = &content[newline_ix + 1..];
 634
 635            let codefence_end = content
 636                .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 637                .context("could not find end marker")?;
 638            let new_text = &content[..codefence_end];
 639
 640            let old_text = snapshot
 641                .text_for_range(excerpt_range.clone())
 642                .collect::<String>();
 643
 644            let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
 645
 646            Ok(Some(InlineCompletion {
 647                id: InlineCompletionId::new(),
 648                path,
 649                excerpt_range,
 650                cursor_offset,
 651                edits: edits.into(),
 652                snapshot: snapshot.clone(),
 653                input_outline: input_outline.into(),
 654                input_events: input_events.into(),
 655                input_excerpt: input_excerpt.into(),
 656                output_excerpt: output_excerpt.into(),
 657                request_sent_at,
 658                response_received_at: Instant::now(),
 659            }))
 660        })
 661    }
 662
 663    pub fn compute_edits(
 664        old_text: String,
 665        new_text: &str,
 666        offset: usize,
 667        snapshot: &BufferSnapshot,
 668    ) -> Vec<(Range<Anchor>, String)> {
 669        let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
 670
 671        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
 672        let mut old_start = offset;
 673        for change in diff.iter_all_changes() {
 674            let value = change.value();
 675            match change.tag() {
 676                similar::ChangeTag::Equal => {
 677                    old_start += value.len();
 678                }
 679                similar::ChangeTag::Delete => {
 680                    let old_end = old_start + value.len();
 681                    if let Some((last_old_range, _)) = edits.last_mut() {
 682                        if last_old_range.end == old_start {
 683                            last_old_range.end = old_end;
 684                        } else {
 685                            edits.push((old_start..old_end, String::new()));
 686                        }
 687                    } else {
 688                        edits.push((old_start..old_end, String::new()));
 689                    }
 690                    old_start = old_end;
 691                }
 692                similar::ChangeTag::Insert => {
 693                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
 694                        if last_old_range.end == old_start {
 695                            last_new_text.push_str(value);
 696                        } else {
 697                            edits.push((old_start..old_start, value.into()));
 698                        }
 699                    } else {
 700                        edits.push((old_start..old_start, value.into()));
 701                    }
 702                }
 703            }
 704        }
 705
 706        edits
 707            .into_iter()
 708            .map(|(mut old_range, new_text)| {
 709                let prefix_len = common_prefix(
 710                    snapshot.chars_for_range(old_range.clone()),
 711                    new_text.chars(),
 712                );
 713                old_range.start += prefix_len;
 714                let suffix_len = common_prefix(
 715                    snapshot.reversed_chars_for_range(old_range.clone()),
 716                    new_text[prefix_len..].chars().rev(),
 717                );
 718                old_range.end = old_range.end.saturating_sub(suffix_len);
 719
 720                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 721                (
 722                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
 723                    new_text,
 724                )
 725            })
 726            .collect()
 727    }
 728
 729    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
 730        self.rated_completions.contains(&completion_id)
 731    }
 732
 733    pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
 734        self.shown_completions.push_front(completion.clone());
 735        if self.shown_completions.len() > 50 {
 736            let completion = self.shown_completions.pop_back().unwrap();
 737            self.rated_completions.remove(&completion.id);
 738        }
 739        cx.notify();
 740    }
 741
    /// Marks `completion` as rated and reports the rating through telemetry.
    ///
    /// Emits an "Inline Completion Rated" event carrying the rating, the prompt inputs, the
    /// model output and the user's `feedback`, then flushes telemetry and notifies observers.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Inline Completion Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Send the event immediately rather than waiting for the next batch.
        self.client.telemetry().flush_events();
        cx.notify();
    }
 762
 763    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
 764        self.shown_completions.iter()
 765    }
 766
 767    pub fn shown_completions_len(&self) -> usize {
 768        self.shown_completions.len()
 769    }
 770
 771    fn report_changes_for_buffer(
 772        &mut self,
 773        buffer: &Entity<Buffer>,
 774        cx: &mut Context<Self>,
 775    ) -> BufferSnapshot {
 776        self.register_buffer(buffer, cx);
 777
 778        let registered_buffer = self
 779            .registered_buffers
 780            .get_mut(&buffer.entity_id())
 781            .unwrap();
 782        let new_snapshot = buffer.read(cx).snapshot();
 783
 784        if new_snapshot.version != registered_buffer.snapshot.version {
 785            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 786            self.push_event(Event::BufferChange {
 787                old_snapshot,
 788                new_snapshot: new_snapshot.clone(),
 789                timestamp: Instant::now(),
 790            });
 791        }
 792
 793        new_snapshot
 794    }
 795
 796    pub fn data_collection_choice_at(&self, path: &Path) -> DataCollectionChoice {
 797        match self.data_collection_preferences.per_worktree.get(path) {
 798            Some(true) => DataCollectionChoice::Enabled,
 799            Some(false) => DataCollectionChoice::Disabled,
 800            None => DataCollectionChoice::NotAnswered,
 801        }
 802    }
 803
 804    fn update_data_collection_choice_for_worktree(
 805        &mut self,
 806        absolute_path_of_project_worktree: PathBuf,
 807        can_collect_data: bool,
 808        cx: &mut Context<Self>,
 809    ) {
 810        self.data_collection_preferences
 811            .per_worktree
 812            .insert(absolute_path_of_project_worktree.clone(), can_collect_data);
 813
 814        db::write_and_log(cx, move || {
 815            persistence::DB
 816                .save_accepted_data_collection(absolute_path_of_project_worktree, can_collect_data)
 817        });
 818    }
 819
 820    fn set_never_ask_again_for_data_collection(&mut self, cx: &mut Context<Self>) {
 821        self.data_collection_preferences.never_ask_again = true;
 822
 823        // persist choice
 824        db::write_and_log(cx, move || {
 825            KEY_VALUE_STORE.write_kvp(
 826                ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into(),
 827                "true".to_string(),
 828            )
 829        });
 830    }
 831
 832    fn load_data_collection_preferences(cx: &mut Context<Self>) -> DataCollectionPreferences {
 833        if env::var("ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES").is_ok() {
 834            db::write_and_log(cx, move || async move {
 835                KEY_VALUE_STORE
 836                    .delete_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into())
 837                    .await
 838                    .log_err();
 839
 840                persistence::DB.clear_all_zeta_preferences().await
 841            });
 842            return DataCollectionPreferences::default();
 843        }
 844
 845        let never_ask_again = KEY_VALUE_STORE
 846            .read_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY)
 847            .log_err()
 848            .flatten()
 849            .map(|value| value == "true")
 850            .unwrap_or(false);
 851
 852        let preferences_per_project = persistence::DB
 853            .get_all_zeta_preferences()
 854            .log_err()
 855            .unwrap_or_else(HashMap::default);
 856
 857        DataCollectionPreferences {
 858            never_ask_again,
 859            per_worktree: preferences_per_project,
 860        }
 861    }
 862}
 863
/// The user's answers about edit-prediction data collection.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct DataCollectionPreferences {
    /// Set when a user clicks on "Never Ask Again". No code path sets it back to
    /// `false`. However, loading with `ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES`
    /// set yields the default value.
    never_ask_again: bool,
    /// Per-worktree answer, keyed by the worktree's absolute root path.
    /// `true` means data collection is enabled for that worktree. A missing
    /// entry means the question has not been answered.
    per_worktree: HashMap<PathBuf, bool>,
}
 870
 871fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 872    a.zip(b)
 873        .take_while(|(a, b)| a == b)
 874        .map(|(a, _)| a.len_utf8())
 875        .sum()
 876}
 877
 878fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
 879    let mut input_outline = String::new();
 880
 881    writeln!(
 882        input_outline,
 883        "```{}",
 884        snapshot
 885            .file()
 886            .map_or(Cow::Borrowed("untitled"), |file| file
 887                .path()
 888                .to_string_lossy())
 889    )
 890    .unwrap();
 891
 892    if let Some(outline) = snapshot.outline(None) {
 893        let guess_size = outline.items.len() * 15;
 894        input_outline.reserve(guess_size);
 895        for item in outline.items.iter() {
 896            let spacing = " ".repeat(item.depth);
 897            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
 898        }
 899    }
 900
 901    writeln!(input_outline, "```").unwrap();
 902
 903    input_outline
 904}
 905
 906fn prompt_for_excerpt(
 907    snapshot: &BufferSnapshot,
 908    excerpt_range: &Range<usize>,
 909    offset: usize,
 910) -> String {
 911    let mut prompt_excerpt = String::new();
 912    writeln!(
 913        prompt_excerpt,
 914        "```{}",
 915        snapshot
 916            .file()
 917            .map_or(Cow::Borrowed("untitled"), |file| file
 918                .path()
 919                .to_string_lossy())
 920    )
 921    .unwrap();
 922
 923    if excerpt_range.start == 0 {
 924        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
 925    }
 926
 927    let point_range = excerpt_range.to_point(snapshot);
 928    if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
 929        let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
 930        for chunk in snapshot.text_for_range(extra_context_line_range) {
 931            prompt_excerpt.push_str(chunk);
 932        }
 933    }
 934    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
 935    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
 936        prompt_excerpt.push_str(chunk);
 937    }
 938    prompt_excerpt.push_str(CURSOR_MARKER);
 939    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
 940        prompt_excerpt.push_str(chunk);
 941    }
 942    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 943
 944    if point_range.end.row < snapshot.max_point().row
 945        && !snapshot.is_line_blank(point_range.end.row + 1)
 946    {
 947        let extra_context_line_range = point_range.end
 948            ..Point::new(
 949                point_range.end.row + 1,
 950                snapshot.line_len(point_range.end.row + 1),
 951            );
 952        for chunk in snapshot.text_for_range(extra_context_line_range) {
 953            prompt_excerpt.push_str(chunk);
 954        }
 955    }
 956
 957    write!(prompt_excerpt, "\n```").unwrap();
 958    prompt_excerpt
 959}
 960
 961fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
 962    const CONTEXT_LINES: u32 = 32;
 963
 964    let mut context_lines_before = CONTEXT_LINES;
 965    let mut context_lines_after = CONTEXT_LINES;
 966    if point.row < CONTEXT_LINES {
 967        context_lines_after += CONTEXT_LINES - point.row;
 968    } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
 969        context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
 970    }
 971
 972    let excerpt_start_row = point.row.saturating_sub(context_lines_before);
 973    let excerpt_start = Point::new(excerpt_start_row, 0);
 974    let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
 975    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
 976    excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
 977}
 978
/// Per-buffer state stored in `registered_buffers`.
struct RegisteredBuffer {
    /// Snapshot last recorded for this buffer. `report_changes_for_buffer`
    /// compares the buffer's current version against it and replaces it when
    /// the version changes.
    snapshot: BufferSnapshot,
    /// Not read directly; held so the two subscriptions stay alive. They are
    /// created outside this chunk, presumably in `register_buffer` (confirm).
    _subscriptions: [gpui::Subscription; 2],
}
 983
/// A recorded user action that `to_prompt` renders into prompt text.
#[derive(Clone)]
enum Event {
    /// A buffer moved from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was observed. `report_changes_for_buffer` sets it
        /// to `Instant::now()`.
        timestamp: Instant,
    },
}
 992
 993impl Event {
 994    fn to_prompt(&self) -> String {
 995        match self {
 996            Event::BufferChange {
 997                old_snapshot,
 998                new_snapshot,
 999                ..
1000            } => {
1001                let mut prompt = String::new();
1002
1003                let old_path = old_snapshot
1004                    .file()
1005                    .map(|f| f.path().as_ref())
1006                    .unwrap_or(Path::new("untitled"));
1007                let new_path = new_snapshot
1008                    .file()
1009                    .map(|f| f.path().as_ref())
1010                    .unwrap_or(Path::new("untitled"));
1011                if old_path != new_path {
1012                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1013                }
1014
1015                let diff =
1016                    similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
1017                        .unified_diff()
1018                        .to_string();
1019                if !diff.is_empty() {
1020                    write!(
1021                        prompt,
1022                        "User edited {:?}:\n```diff\n{}\n```",
1023                        new_path, diff
1024                    )
1025                    .unwrap();
1026                }
1027
1028                prompt
1029            }
1030        }
1031    }
1032}
1033
/// The completion currently offered by the provider, tagged with its buffer.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion was requested for.
    buffer_id: EntityId,
    completion: InlineCompletion,
}
1039
1040impl CurrentInlineCompletion {
1041    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1042        if self.buffer_id != old_completion.buffer_id {
1043            return true;
1044        }
1045
1046        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1047            return true;
1048        };
1049        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1050            return false;
1051        };
1052
1053        if old_edits.len() == 1 && new_edits.len() == 1 {
1054            let (old_range, old_text) = &old_edits[0];
1055            let (new_range, new_text) = &new_edits[0];
1056            new_range == old_range && new_text.starts_with(old_text)
1057        } else {
1058            true
1059        }
1060    }
1061}
1062
/// An in-flight prediction request started by `refresh`.
struct PendingCompletion {
    /// Taken from `next_pending_completion_id`, which only increases. The
    /// finishing task uses it to find its own entry.
    id: usize,
    /// Held so the request keeps running. Dropping it presumably cancels the
    /// task (gpui `Task` semantics; confirm).
    _task: Task<()>,
}
1067
/// The user's answer to the data-collection question for one worktree.
#[derive(Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer recorded yet.
    NotAnswered,
    /// The user allowed data collection.
    Enabled,
    /// The user declined data collection.
    Disabled,
}
1074
1075impl DataCollectionChoice {
1076    pub fn is_enabled(&self) -> bool {
1077        match self {
1078            Self::Enabled => true,
1079            Self::NotAnswered | Self::Disabled => false,
1080        }
1081    }
1082
1083    pub fn is_answered(&self) -> bool {
1084        match self {
1085            Self::Enabled | Self::Disabled => true,
1086            Self::NotAnswered => false,
1087        }
1088    }
1089
1090    pub fn toggle(&self) -> DataCollectionChoice {
1091        match self {
1092            Self::Enabled => Self::Disabled,
1093            Self::Disabled => Self::Enabled,
1094            Self::NotAnswered => Self::Enabled,
1095        }
1096    }
1097}
1098
/// Editor-facing inline-completion provider backed by [`Zeta`].
pub struct ZetaInlineCompletionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests. Holds at most two, oldest first.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request started by `refresh`.
    next_pending_completion_id: usize,
    /// The completion currently offered to the editor, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection context for the current worktree. `None` when it does
    /// not apply.
    data_collection: Option<ProviderDataCollection>,
}
1106
/// Data-collection context for the worktree that owns the edited buffer.
pub struct ProviderDataCollection {
    /// Workspace used to show the opt-in notification.
    workspace: WeakEntity<Workspace>,
    /// Absolute root path of the worktree. This is the key under which the
    /// choice is stored.
    worktree_root_path: PathBuf,
    /// Local copy of the user's choice for this worktree.
    choice: DataCollectionChoice,
}
1112
impl ProviderDataCollection {
    /// Builds the data-collection context for `buffer` within `workspace`.
    ///
    /// Returns `None` in any of these cases:
    /// - there is no workspace or no buffer;
    /// - the buffer has no file;
    /// - the file is not local, or `is_private()` reports true;
    /// - the workspace cannot resolve the absolute path of the file's worktree.
    ///
    /// The initial choice is read from `zeta`'s stored preferences for that
    /// worktree root.
    pub fn new(
        zeta: Entity<Zeta>,
        workspace: Option<Entity<Workspace>>,
        buffer: Option<Entity<Buffer>>,
        cx: &mut App,
    ) -> Option<ProviderDataCollection> {
        let workspace = workspace?;

        let worktree_root_path = buffer?.update(cx, |buffer, cx| {
            let file = buffer.file()?;

            // Only local, non-private files are eligible for data collection.
            if !file.is_local() || file.is_private() {
                return None;
            }

            workspace.update(cx, |workspace, cx| {
                Some(
                    workspace
                        .absolute_path_of_worktree(file.worktree_id(cx), cx)?
                        .to_path_buf(),
                )
            })
        })?;

        let choice = zeta.read(cx).data_collection_choice_at(&worktree_root_path);

        Some(ProviderDataCollection {
            workspace: workspace.downgrade(),
            worktree_root_path,
            choice,
        })
    }

    /// Updates the local choice and persists it for this worktree through `zeta`.
    fn set_choice(&mut self, choice: DataCollectionChoice, zeta: &Entity<Zeta>, cx: &mut App) {
        self.choice = choice;

        let worktree_root_path = self.worktree_root_path.clone();

        zeta.update(cx, |zeta, cx| {
            zeta.update_data_collection_choice_for_worktree(
                worktree_root_path,
                choice.is_enabled(),
                cx,
            )
        });
    }

    /// Flips the current choice (an unanswered choice becomes enabled) and
    /// persists the result.
    fn toggle_choice(&mut self, zeta: &Entity<Zeta>, cx: &mut App) {
        self.set_choice(self.choice.toggle(), zeta, cx);
    }
}
1165
1166impl ZetaInlineCompletionProvider {
1167    pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
1168
1169    pub fn new(zeta: Entity<Zeta>, data_collection: Option<ProviderDataCollection>) -> Self {
1170        Self {
1171            zeta,
1172            pending_completions: ArrayVec::new(),
1173            next_pending_completion_id: 0,
1174            current_completion: None,
1175            data_collection,
1176        }
1177    }
1178
1179    fn set_data_collection_choice(&mut self, choice: DataCollectionChoice, cx: &mut App) {
1180        if let Some(data_collection) = self.data_collection.as_mut() {
1181            data_collection.set_choice(choice, &self.zeta, cx);
1182        }
1183    }
1184}
1185
1186impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
1187    fn name() -> &'static str {
1188        "zed-predict"
1189    }
1190
1191    fn display_name() -> &'static str {
1192        "Zed's Edit Predictions"
1193    }
1194
1195    fn show_completions_in_menu() -> bool {
1196        true
1197    }
1198
1199    fn show_completions_in_normal_mode() -> bool {
1200        true
1201    }
1202
1203    fn show_tab_accept_marker() -> bool {
1204        true
1205    }
1206
1207    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
1208        let Some(data_collection) = self.data_collection.as_ref() else {
1209            return DataCollectionState::Unknown;
1210        };
1211
1212        if data_collection.choice.is_enabled() {
1213            DataCollectionState::Enabled
1214        } else {
1215            DataCollectionState::Disabled
1216        }
1217    }
1218
1219    fn toggle_data_collection(&mut self, cx: &mut App) {
1220        if let Some(data_collection) = self.data_collection.as_mut() {
1221            data_collection.toggle_choice(&self.zeta, cx);
1222        }
1223    }
1224
1225    fn is_enabled(
1226        &self,
1227        buffer: &Entity<Buffer>,
1228        cursor_position: language::Anchor,
1229        cx: &App,
1230    ) -> bool {
1231        let buffer = buffer.read(cx);
1232        let file = buffer.file();
1233        let language = buffer.language_at(cursor_position);
1234        let settings = all_language_settings(file, cx);
1235        settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1236    }
1237
1238    fn needs_terms_acceptance(&self, cx: &App) -> bool {
1239        !self.zeta.read(cx).tos_accepted
1240    }
1241
1242    fn is_refreshing(&self) -> bool {
1243        !self.pending_completions.is_empty()
1244    }
1245
1246    fn refresh(
1247        &mut self,
1248        buffer: Entity<Buffer>,
1249        position: language::Anchor,
1250        debounce: bool,
1251        cx: &mut Context<Self>,
1252    ) {
1253        if !self.zeta.read(cx).tos_accepted {
1254            return;
1255        }
1256
1257        let pending_completion_id = self.next_pending_completion_id;
1258        self.next_pending_completion_id += 1;
1259        let can_collect_data = self
1260            .data_collection
1261            .as_ref()
1262            .map_or(false, |data_collection| data_collection.choice.is_enabled());
1263
1264        let task = cx.spawn(|this, mut cx| async move {
1265            if debounce {
1266                cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1267            }
1268
1269            let completion_request = this.update(&mut cx, |this, cx| {
1270                this.zeta.update(cx, |zeta, cx| {
1271                    zeta.request_completion(&buffer, position, can_collect_data, cx)
1272                })
1273            });
1274
1275            let completion = match completion_request {
1276                Ok(completion_request) => {
1277                    let completion_request = completion_request.await;
1278                    completion_request.map(|c| {
1279                        c.map(|completion| CurrentInlineCompletion {
1280                            buffer_id: buffer.entity_id(),
1281                            completion,
1282                        })
1283                    })
1284                }
1285                Err(error) => Err(error),
1286            };
1287            let Some(new_completion) = completion
1288                .context("edit prediction failed")
1289                .log_err()
1290                .flatten()
1291            else {
1292                return;
1293            };
1294
1295            this.update(&mut cx, |this, cx| {
1296                if this.pending_completions[0].id == pending_completion_id {
1297                    this.pending_completions.remove(0);
1298                } else {
1299                    this.pending_completions.clear();
1300                }
1301
1302                if let Some(old_completion) = this.current_completion.as_ref() {
1303                    let snapshot = buffer.read(cx).snapshot();
1304                    if new_completion.should_replace_completion(&old_completion, &snapshot) {
1305                        this.zeta.update(cx, |zeta, cx| {
1306                            zeta.completion_shown(&new_completion.completion, cx);
1307                        });
1308                        this.current_completion = Some(new_completion);
1309                    }
1310                } else {
1311                    this.zeta.update(cx, |zeta, cx| {
1312                        zeta.completion_shown(&new_completion.completion, cx);
1313                    });
1314                    this.current_completion = Some(new_completion);
1315                }
1316
1317                cx.notify();
1318            })
1319            .ok();
1320        });
1321
1322        // We always maintain at most two pending completions. When we already
1323        // have two, we replace the newest one.
1324        if self.pending_completions.len() <= 1 {
1325            self.pending_completions.push(PendingCompletion {
1326                id: pending_completion_id,
1327                _task: task,
1328            });
1329        } else if self.pending_completions.len() == 2 {
1330            self.pending_completions.pop();
1331            self.pending_completions.push(PendingCompletion {
1332                id: pending_completion_id,
1333                _task: task,
1334            });
1335        }
1336    }
1337
1338    fn cycle(
1339        &mut self,
1340        _buffer: Entity<Buffer>,
1341        _cursor_position: language::Anchor,
1342        _direction: inline_completion::Direction,
1343        _cx: &mut Context<Self>,
1344    ) {
1345        // Right now we don't support cycling.
1346    }
1347
1348    fn accept(&mut self, cx: &mut Context<Self>) {
1349        self.pending_completions.clear();
1350
1351        let Some(data_collection) = self.data_collection.as_mut() else {
1352            return;
1353        };
1354
1355        if data_collection.choice.is_answered()
1356            || self
1357                .zeta
1358                .read(cx)
1359                .data_collection_preferences
1360                .never_ask_again
1361        {
1362            return;
1363        }
1364
1365        struct ZetaDataCollectionNotification;
1366        let notification_id = NotificationId::unique::<ZetaDataCollectionNotification>();
1367
1368        const DATA_COLLECTION_INFO_URL: &str = "https://zed.dev/terms-of-service"; // TODO: Replace for a link that's dedicated to Edit Predictions data collection
1369
1370        let this = cx.entity();
1371        data_collection
1372            .workspace
1373            .update(cx, |workspace, cx| {
1374                workspace.show_notification(notification_id, cx, |cx| {
1375                    let zeta = self.zeta.clone();
1376
1377                    cx.new(move |_cx| {
1378                        let message =
1379                            "To allow Zed to suggest better edits, turn on data collection. You \
1380                            can turn off at any time via the status bar menu.";
1381                        MessageNotification::new(message)
1382                            .with_title("Per-Project Data Collection Program")
1383                            .show_close_button(false)
1384                            .with_click_message("Turn On")
1385                            .on_click({
1386                                let this = this.clone();
1387                                move |_window, cx| {
1388                                    this.update(cx, |this, cx| {
1389                                        this.set_data_collection_choice(
1390                                            DataCollectionChoice::Enabled,
1391                                            cx,
1392                                        )
1393                                    });
1394                                }
1395                            })
1396                            .with_secondary_click_message("Turn Off")
1397                            .on_secondary_click({
1398                                move |_window, cx| {
1399                                    this.update(cx, |this, cx| {
1400                                        this.set_data_collection_choice(
1401                                            DataCollectionChoice::Disabled,
1402                                            cx,
1403                                        )
1404                                    });
1405                                }
1406                            })
1407                            .with_tertiary_click_message("Never Ask Again")
1408                            .on_tertiary_click({
1409                                let zeta = zeta.clone();
1410                                move |_window, cx| {
1411                                    zeta.update(cx, |zeta, cx| {
1412                                        zeta.set_never_ask_again_for_data_collection(cx);
1413                                    });
1414                                }
1415                            })
1416                            .more_info_message("Learn More")
1417                            .more_info_url(DATA_COLLECTION_INFO_URL)
1418                    })
1419                });
1420            })
1421            .log_err();
1422    }
1423
1424    fn discard(&mut self, _cx: &mut Context<Self>) {
1425        self.pending_completions.clear();
1426        self.current_completion.take();
1427    }
1428
1429    fn suggest(
1430        &mut self,
1431        buffer: &Entity<Buffer>,
1432        cursor_position: language::Anchor,
1433        cx: &mut Context<Self>,
1434    ) -> Option<inline_completion::InlineCompletion> {
1435        let CurrentInlineCompletion {
1436            buffer_id,
1437            completion,
1438            ..
1439        } = self.current_completion.as_mut()?;
1440
1441        // Invalidate previous completion if it was generated for a different buffer.
1442        if *buffer_id != buffer.entity_id() {
1443            self.current_completion.take();
1444            return None;
1445        }
1446
1447        let buffer = buffer.read(cx);
1448        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1449            self.current_completion.take();
1450            return None;
1451        };
1452
1453        let cursor_row = cursor_position.to_point(buffer).row;
1454        let (closest_edit_ix, (closest_edit_range, _)) =
1455            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1456                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1457                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1458                cmp::min(distance_from_start, distance_from_end)
1459            })?;
1460
1461        let mut edit_start_ix = closest_edit_ix;
1462        for (range, _) in edits[..edit_start_ix].iter().rev() {
1463            let distance_from_closest_edit =
1464                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1465            if distance_from_closest_edit <= 1 {
1466                edit_start_ix -= 1;
1467            } else {
1468                break;
1469            }
1470        }
1471
1472        let mut edit_end_ix = closest_edit_ix + 1;
1473        for (range, _) in &edits[edit_end_ix..] {
1474            let distance_from_closest_edit =
1475                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1476            if distance_from_closest_edit <= 1 {
1477                edit_end_ix += 1;
1478            } else {
1479                break;
1480            }
1481        }
1482
1483        Some(inline_completion::InlineCompletion {
1484            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1485        })
1486    }
1487}
1488
// Unit tests for completion interpolation and end-of-buffer completion requests.
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Checks that `InlineCompletion::interpolate` tracks buffer edits: typing
    /// the suggested text shrinks or removes edits, undo restores them, and a
    /// diverging edit invalidates the completion.
    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let completion = InlineCompletion {
            edits: cx
                .read(|cx| {
                    to_completion_edits(
                        [(2..5, "REM".to_string()), (9..11, "".to_string())],
                        &buffer,
                        cx,
                    )
                })
                .into(),
            path: Path::new("").into(),
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            id: InlineCompletionId::new(),
            excerpt_range: 0..0,
            cursor_offset: 0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Unchanged buffer: edits come back as given.
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Deleting the replaced text turns the first edit into an insertion
        // and shifts the second edit left.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
        );

        // Undo restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Typing the first character of the suggestion leaves the remainder.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Fully typing the first edit removes it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(9..11, "".to_string())]
        );

        // Deleting the typed character brings the remaining insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Performing the suggested deletion removes the second edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            cx.read(|cx| {
                from_completion_edits(
                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            }),
            vec![(4..4, "M".to_string())]
        );

        // An edit that diverges from the suggestion invalidates the completion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(
            cx.read(|cx| completion.interpolate(&buffer.read(cx).snapshot())),
            None
        );
    }

    /// Requests a completion at the end of a buffer through a fake HTTP client
    /// and LLM-token server. Applying the returned edits should append the
    /// predicted line.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request is answered with the canned prediction above.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));

        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task = zeta.update(cx, |zeta, cx| {
            zeta.request_completion(&buffer, cursor, false, cx)
        });

        // The request first fetches an LLM token from the fake RPC server.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap().unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`. The
    /// start uses `anchor_after` and the end uses `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in `buffer`'s current contents.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Enables `env_logger` output for tests when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}