zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_banner;
   6mod onboarding_modal;
   7mod onboarding_telemetry;
   8mod rate_completion_modal;
   9
  10pub(crate) use completion_diff_element::*;
  11use db::kvp::KEY_VALUE_STORE;
  12pub use init::*;
  13use inline_completion::DataCollectionState;
  14pub use license_detection::is_license_eligible_for_data_collection;
  15use license_detection::LICENSE_FILES_TO_CHECK;
  16pub use onboarding_banner::*;
  17pub use rate_completion_modal::*;
  18
  19use anyhow::{anyhow, Context as _, Result};
  20use arrayvec::ArrayVec;
  21use client::{Client, UserStore};
  22use collections::{HashMap, HashSet, VecDeque};
  23use feature_flags::FeatureFlagAppExt as _;
  24use futures::AsyncReadExt;
  25use gpui::{
  26    actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
  27};
  28use http_client::{HttpClient, Method};
  29use input_excerpt::excerpt_for_cursor_position;
  30use language::{
  31    Anchor, Buffer, BufferSnapshot, CharClassifier, CharKind, EditPreview, OffsetRangeExt,
  32    ToOffset, ToPoint,
  33};
  34use language_models::LlmApiToken;
  35use postage::watch;
  36use project::Project;
  37use settings::WorktreeId;
  38use std::{
  39    borrow::Cow,
  40    cmp,
  41    fmt::Write,
  42    future::Future,
  43    mem,
  44    ops::Range,
  45    path::Path,
  46    rc::Rc,
  47    sync::Arc,
  48    time::{Duration, Instant},
  49};
  50use telemetry_events::InlineCompletionRating;
  51use util::ResultExt;
  52use uuid::Uuid;
  53use worktree::Worktree;
  54use zed_llm_client::{PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
  55
  56const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  57const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  58const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  59const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  60const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  61const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  62
  63const MAX_CONTEXT_TOKENS: usize = 150;
  64const MAX_REWRITE_TOKENS: usize = 350;
  65const MAX_EVENT_TOKENS: usize = 500;
  66
  67/// Maximum number of events to track.
  68const MAX_EVENT_COUNT: usize = 16;
  69
// Declares the `ClearHistory` action under the `edit_prediction` namespace.
actions!(edit_prediction, [ClearHistory]);
  71
  72#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  73pub struct InlineCompletionId(Uuid);
  74
  75impl From<InlineCompletionId> for gpui::ElementId {
  76    fn from(value: InlineCompletionId) -> Self {
  77        gpui::ElementId::Uuid(value.0)
  78    }
  79}
  80
  81impl std::fmt::Display for InlineCompletionId {
  82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  83        write!(f, "{}", self.0)
  84    }
  85}
  86
/// App-global slot holding the single `Zeta` entity; installed by `Zeta::register` and
/// read back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  91
/// A predicted set of edits for one buffer, together with the prompt inputs and the raw
/// model output that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Identifier built from the prediction response's `request_id`.
    id: InlineCompletionId,
    /// Full path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Offset range of the editable region, computed against the snapshot the request was
    /// built from. That snapshot may be older than `snapshot` below.
    excerpt_range: Range<usize>,
    /// Cursor offset at the time the request was built (same snapshot as `excerpt_range`).
    cursor_offset: usize,
    /// Predicted edits, re-based onto `snapshot` by `interpolate`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot that `edits` were last interpolated against.
    snapshot: BufferSnapshot,
    /// Preview of `edits` produced by `Buffer::preview_edits`.
    edit_preview: EditPreview,
    /// Outline text sent with the request.
    input_outline: Arc<str>,
    /// Edit-history text sent with the request.
    input_events: Arc<str>,
    /// Excerpt prompt sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt returned by the model.
    output_excerpt: Arc<str>,
    /// Taken before the request's prompt inputs are computed on the background executor.
    request_sent_at: Instant,
    /// Taken after the response has been parsed, interpolated and previewed.
    response_received_at: Instant,
}
 108
 109impl InlineCompletion {
 110    fn latency(&self) -> Duration {
 111        self.response_received_at
 112            .duration_since(self.request_sent_at)
 113    }
 114
 115    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 116        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 117    }
 118}
 119
 120fn interpolate(
 121    old_snapshot: &BufferSnapshot,
 122    new_snapshot: &BufferSnapshot,
 123    current_edits: Arc<[(Range<Anchor>, String)]>,
 124) -> Option<Vec<(Range<Anchor>, String)>> {
 125    let mut edits = Vec::new();
 126
 127    let mut model_edits = current_edits.into_iter().peekable();
 128    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 129        while let Some((model_old_range, _)) = model_edits.peek() {
 130            let model_old_range = model_old_range.to_offset(old_snapshot);
 131            if model_old_range.end < user_edit.old.start {
 132                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 133                edits.push((model_old_range.clone(), model_new_text.clone()));
 134            } else {
 135                break;
 136            }
 137        }
 138
 139        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 140            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 141            if user_edit.old == model_old_offset_range {
 142                let user_new_text = new_snapshot
 143                    .text_for_range(user_edit.new.clone())
 144                    .collect::<String>();
 145
 146                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 147                    if !model_suffix.is_empty() {
 148                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 149                        edits.push((anchor..anchor, model_suffix.to_string()));
 150                    }
 151
 152                    model_edits.next();
 153                    continue;
 154                }
 155            }
 156        }
 157
 158        return None;
 159    }
 160
 161    edits.extend(model_edits.cloned());
 162
 163    if edits.is_empty() {
 164        None
 165    } else {
 166        Some(edits)
 167    }
 168}
 169
 170impl std::fmt::Debug for InlineCompletion {
 171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 172        f.debug_struct("InlineCompletion")
 173            .field("id", &self.id)
 174            .field("path", &self.path)
 175            .field("edits", &self.edits)
 176            .finish_non_exhaustive()
 177    }
 178}
 179
/// Edit-prediction state shared by the app: recent edit history used to build prompts,
/// buffers tracked for changes, and the credentials and subscriptions needed to call the
/// prediction service.
pub struct Zeta {
    client: Arc<Client>,
    /// Recent buffer changes, coalesced and capped by `push_event`; cloned into each request.
    events: VecDeque<Event>,
    /// Buffers registered via `register_buffer`, keyed by entity id. Each entry holds the
    /// snapshot taken at registration (NOTE(review): presumably refreshed by
    /// `report_changes_for_buffer`, which is not visible here) plus its subscriptions.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// NOTE(review): not populated in this section; presumably completions shown to the user.
    shown_completions: VecDeque<InlineCompletion>,
    /// NOTE(review): not populated in this section; presumably ids of completions the user rated.
    rated_completions: HashSet<InlineCompletionId>,
    /// Data collection choice, initialised from `load_data_collection_choices`.
    data_collection_choice: Entity<DataCollectionChoice>,
    /// Token for prediction requests; refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Whether the terms of service have been accepted.
    tos_accepted: bool,
    /// Keeps `tos_accepted` in sync on `PrivateUserInfoUpdated`.
    _user_store_subscription: Subscription,
    /// License detection watchers per worktree, created on demand by `register`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
 194
 195impl Zeta {
 196    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 197        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 198    }
 199
 200    pub fn register(
 201        worktree: Option<Entity<Worktree>>,
 202        client: Arc<Client>,
 203        user_store: Entity<UserStore>,
 204        cx: &mut App,
 205    ) -> Entity<Self> {
 206        let this = Self::global(cx).unwrap_or_else(|| {
 207            let entity = cx.new(|cx| Self::new(client, user_store, cx));
 208            cx.set_global(ZetaGlobal(entity.clone()));
 209            entity
 210        });
 211
 212        this.update(cx, move |this, cx| {
 213            if let Some(worktree) = worktree {
 214                worktree.update(cx, |worktree, cx| {
 215                    this.license_detection_watchers
 216                        .entry(worktree.id())
 217                        .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
 218                });
 219            }
 220        });
 221
 222        this
 223    }
 224
 225    pub fn clear_history(&mut self) {
 226        self.events.clear();
 227    }
 228
 229    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 230        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
 231
 232        let data_collection_choice = Self::load_data_collection_choices();
 233        let data_collection_choice = cx.new(|_| data_collection_choice);
 234
 235        Self {
 236            client,
 237            events: VecDeque::new(),
 238            shown_completions: VecDeque::new(),
 239            rated_completions: HashSet::default(),
 240            registered_buffers: HashMap::default(),
 241            data_collection_choice,
 242            llm_token: LlmApiToken::default(),
 243            _llm_token_subscription: cx.subscribe(
 244                &refresh_llm_token_listener,
 245                |this, _listener, _event, cx| {
 246                    let client = this.client.clone();
 247                    let llm_token = this.llm_token.clone();
 248                    cx.spawn(|_this, _cx| async move {
 249                        llm_token.refresh(&client).await?;
 250                        anyhow::Ok(())
 251                    })
 252                    .detach_and_log_err(cx);
 253                },
 254            ),
 255            tos_accepted: user_store
 256                .read(cx)
 257                .current_user_has_accepted_terms()
 258                .unwrap_or(false),
 259            _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
 260                match event {
 261                    client::user::Event::PrivateUserInfoUpdated => {
 262                        this.tos_accepted = user_store
 263                            .read(cx)
 264                            .current_user_has_accepted_terms()
 265                            .unwrap_or(false);
 266                    }
 267                    _ => {}
 268                }
 269            }),
 270            license_detection_watchers: HashMap::default(),
 271        }
 272    }
 273
 274    fn push_event(&mut self, event: Event) {
 275        if let Some(Event::BufferChange {
 276            new_snapshot: last_new_snapshot,
 277            timestamp: last_timestamp,
 278            ..
 279        }) = self.events.back_mut()
 280        {
 281            // Coalesce edits for the same buffer when they happen one after the other.
 282            let Event::BufferChange {
 283                old_snapshot,
 284                new_snapshot,
 285                timestamp,
 286            } = &event;
 287
 288            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 289                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 290                && old_snapshot.version == last_new_snapshot.version
 291            {
 292                *last_new_snapshot = new_snapshot.clone();
 293                *last_timestamp = *timestamp;
 294                return;
 295            }
 296        }
 297
 298        self.events.push_back(event);
 299        if self.events.len() >= MAX_EVENT_COUNT {
 300            self.events.drain(..MAX_EVENT_COUNT / 2);
 301        }
 302    }
 303
 304    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 305        let buffer_id = buffer.entity_id();
 306        let weak_buffer = buffer.downgrade();
 307
 308        if let std::collections::hash_map::Entry::Vacant(entry) =
 309            self.registered_buffers.entry(buffer_id)
 310        {
 311            let snapshot = buffer.read(cx).snapshot();
 312
 313            entry.insert(RegisteredBuffer {
 314                snapshot,
 315                _subscriptions: [
 316                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 317                        this.handle_buffer_event(buffer, event, cx);
 318                    }),
 319                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 320                        this.registered_buffers.remove(&weak_buffer.entity_id());
 321                    }),
 322                ],
 323            });
 324        };
 325    }
 326
 327    fn handle_buffer_event(
 328        &mut self,
 329        buffer: Entity<Buffer>,
 330        event: &language::BufferEvent,
 331        cx: &mut Context<Self>,
 332    ) {
 333        if let language::BufferEvent::Edited = event {
 334            self.report_changes_for_buffer(&buffer, cx);
 335        }
 336    }
 337
 338    pub fn request_completion_impl<F, R>(
 339        &mut self,
 340        project: Option<&Entity<Project>>,
 341        buffer: &Entity<Buffer>,
 342        cursor: language::Anchor,
 343        can_collect_data: bool,
 344        cx: &mut Context<Self>,
 345        perform_predict_edits: F,
 346    ) -> Task<Result<Option<InlineCompletion>>>
 347    where
 348        F: FnOnce(Arc<Client>, LlmApiToken, bool, PredictEditsBody) -> R + 'static,
 349        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 350    {
 351        let snapshot = self.report_changes_for_buffer(&buffer, cx);
 352        let diagnostic_groups = snapshot.diagnostic_groups(None);
 353        let cursor_point = cursor.to_point(&snapshot);
 354        let cursor_offset = cursor_point.to_offset(&snapshot);
 355        let events = self.events.clone();
 356        let path: Arc<Path> = snapshot
 357            .file()
 358            .map(|f| Arc::from(f.full_path(cx).as_path()))
 359            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 360
 361        let client = self.client.clone();
 362        let llm_token = self.llm_token.clone();
 363        let is_staff = cx.is_staff();
 364
 365        let buffer = buffer.clone();
 366
 367        let local_lsp_store =
 368            project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
 369        let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
 370            Some(
 371                diagnostic_groups
 372                    .into_iter()
 373                    .filter_map(|(language_server_id, diagnostic_group)| {
 374                        let language_server =
 375                            local_lsp_store.running_language_server_for_id(language_server_id)?;
 376
 377                        Some((
 378                            language_server.name(),
 379                            diagnostic_group.resolve::<usize>(&snapshot),
 380                        ))
 381                    })
 382                    .collect::<Vec<_>>(),
 383            )
 384        } else {
 385            None
 386        };
 387
 388        cx.spawn(|_, cx| async move {
 389            let request_sent_at = Instant::now();
 390
 391            struct BackgroundValues {
 392                input_events: String,
 393                input_excerpt: String,
 394                speculated_output: String,
 395                editable_range: Range<usize>,
 396                input_outline: String,
 397            }
 398
 399            let values = cx
 400                .background_executor()
 401                .spawn({
 402                    let snapshot = snapshot.clone();
 403                    let path = path.clone();
 404                    async move {
 405                        let path = path.to_string_lossy();
 406                        let input_excerpt = excerpt_for_cursor_position(
 407                            cursor_point,
 408                            &path,
 409                            &snapshot,
 410                            MAX_REWRITE_TOKENS,
 411                            MAX_CONTEXT_TOKENS,
 412                        );
 413                        let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
 414                        let input_outline = prompt_for_outline(&snapshot);
 415
 416                        anyhow::Ok(BackgroundValues {
 417                            input_events,
 418                            input_excerpt: input_excerpt.prompt,
 419                            speculated_output: input_excerpt.speculated_output,
 420                            editable_range: input_excerpt.editable_range.to_offset(&snapshot),
 421                            input_outline,
 422                        })
 423                    }
 424                })
 425                .await?;
 426
 427            log::debug!(
 428                "Events:\n{}\nExcerpt:\n{:?}",
 429                values.input_events,
 430                values.input_excerpt
 431            );
 432
 433            let body = PredictEditsBody {
 434                input_events: values.input_events.clone(),
 435                input_excerpt: values.input_excerpt.clone(),
 436                speculated_output: Some(values.speculated_output),
 437                outline: Some(values.input_outline.clone()),
 438                can_collect_data,
 439                diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
 440                    diagnostic_groups
 441                        .into_iter()
 442                        .map(|(name, diagnostic_group)| {
 443                            Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
 444                        })
 445                        .collect::<Result<Vec<_>>>()
 446                        .log_err()
 447                }),
 448            };
 449
 450            let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
 451
 452            log::debug!("completion response: {}", &response.output_excerpt);
 453
 454            Self::process_completion_response(
 455                response,
 456                buffer,
 457                &snapshot,
 458                values.editable_range,
 459                cursor_offset,
 460                path,
 461                values.input_outline,
 462                values.input_events,
 463                values.input_excerpt,
 464                request_sent_at,
 465                &cx,
 466            )
 467            .await
 468        })
 469    }
 470
    /// Generates several example completions of various states to fill the Zeta completion modal.
    ///
    /// Every completion is requested through `fake_completion` against a throwaway local
    /// buffer, so no network request is made.
    #[cfg(any(test, feature = "test-support"))]
    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
        use language::Point;

        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
            And maybe a short line

            Then a few lines

            and then another
            "#};

        let project = None;
        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
        // Every fake request is made with the cursor at the start of the second line.
        let position = buffer.read(cx).anchor_before(Point::new(1, 0));

        let completion_tasks = vec![
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
[here's an edit]
And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        ", ),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
[and another edit]
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line

Then a few lines

and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
[a third completion]
and then another
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
and then another
[fourth completion example]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
            self.fake_completion(
                project,
                &buffer,
                position,
                PredictEditsResponse {
                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
Then a few lines
and then another
[fifth and final completion]
{EDITABLE_REGION_END_MARKER}
                        "#),
                },
                cx,
            ),
        ];

        // NOTE(review): each task's `Option<InlineCompletion>` result is discarded here, yet
        // entries 2 and 3 of `shown_completions` are then `unwrap()`ed. This panics unless
        // something else has already put at least four entries in `shown_completions`.
        // Confirm who populates it.
        cx.spawn(|zeta, mut cx| async move {
            for task in completion_tasks {
                task.await.unwrap();
            }

            zeta.update(&mut cx, |zeta, _cx| {
                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
            })
            .ok();
        })
    }
 623
 624    #[cfg(any(test, feature = "test-support"))]
 625    pub fn fake_completion(
 626        &mut self,
 627        project: Option<&Entity<Project>>,
 628        buffer: &Entity<Buffer>,
 629        position: language::Anchor,
 630        response: PredictEditsResponse,
 631        cx: &mut Context<Self>,
 632    ) -> Task<Result<Option<InlineCompletion>>> {
 633        use std::future::ready;
 634
 635        self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| {
 636            ready(Ok(response))
 637        })
 638    }
 639
 640    pub fn request_completion(
 641        &mut self,
 642        project: Option<&Entity<Project>>,
 643        buffer: &Entity<Buffer>,
 644        position: language::Anchor,
 645        can_collect_data: bool,
 646        cx: &mut Context<Self>,
 647    ) -> Task<Result<Option<InlineCompletion>>> {
 648        self.request_completion_impl(
 649            project,
 650            buffer,
 651            position,
 652            can_collect_data,
 653            cx,
 654            Self::perform_predict_edits,
 655        )
 656    }
 657
 658    fn perform_predict_edits(
 659        client: Arc<Client>,
 660        llm_token: LlmApiToken,
 661        _is_staff: bool,
 662        body: PredictEditsBody,
 663    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 664        async move {
 665            let http_client = client.http_client();
 666            let mut token = llm_token.acquire(&client).await?;
 667            let mut did_retry = false;
 668
 669            loop {
 670                let request_builder = http_client::Request::builder().method(Method::POST);
 671                let request_builder =
 672                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 673                        request_builder.uri(predict_edits_url)
 674                    } else {
 675                        request_builder.uri(
 676                            http_client
 677                                .build_zed_llm_url("/predict_edits/v2", &[])?
 678                                .as_ref(),
 679                        )
 680                    };
 681                let request = request_builder
 682                    .header("Content-Type", "application/json")
 683                    .header("Authorization", format!("Bearer {}", token))
 684                    .body(serde_json::to_string(&body)?.into())?;
 685
 686                let mut response = http_client.send(request).await?;
 687
 688                if response.status().is_success() {
 689                    let mut body = String::new();
 690                    response.body_mut().read_to_string(&mut body).await?;
 691                    return Ok(serde_json::from_str(&body)?);
 692                } else if !did_retry
 693                    && response
 694                        .headers()
 695                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 696                        .is_some()
 697                {
 698                    did_retry = true;
 699                    token = llm_token.refresh(&client).await?;
 700                } else {
 701                    let mut body = String::new();
 702                    response.body_mut().read_to_string(&mut body).await?;
 703                    return Err(anyhow!(
 704                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 705                        response.status(),
 706                        body
 707                    ));
 708                }
 709            }
 710        }
 711    }
 712
 713    #[allow(clippy::too_many_arguments)]
 714    fn process_completion_response(
 715        prediction_response: PredictEditsResponse,
 716        buffer: Entity<Buffer>,
 717        snapshot: &BufferSnapshot,
 718        editable_range: Range<usize>,
 719        cursor_offset: usize,
 720        path: Arc<Path>,
 721        input_outline: String,
 722        input_events: String,
 723        input_excerpt: String,
 724        request_sent_at: Instant,
 725        cx: &AsyncApp,
 726    ) -> Task<Result<Option<InlineCompletion>>> {
 727        let snapshot = snapshot.clone();
 728        let request_id = prediction_response.request_id;
 729        let output_excerpt = prediction_response.output_excerpt;
 730        cx.spawn(|cx| async move {
 731            let output_excerpt: Arc<str> = output_excerpt.into();
 732
 733            let edits: Arc<[(Range<Anchor>, String)]> = cx
 734                .background_executor()
 735                .spawn({
 736                    let output_excerpt = output_excerpt.clone();
 737                    let editable_range = editable_range.clone();
 738                    let snapshot = snapshot.clone();
 739                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
 740                })
 741                .await?
 742                .into();
 743
 744            let Some((edits, snapshot, edit_preview)) = buffer.read_with(&cx, {
 745                let edits = edits.clone();
 746                |buffer, cx| {
 747                    let new_snapshot = buffer.snapshot();
 748                    let edits: Arc<[(Range<Anchor>, String)]> =
 749                        interpolate(&snapshot, &new_snapshot, edits)?.into();
 750                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 751                }
 752            })?
 753            else {
 754                return anyhow::Ok(None);
 755            };
 756
 757            let edit_preview = edit_preview.await;
 758
 759            Ok(Some(InlineCompletion {
 760                id: InlineCompletionId(request_id),
 761                path,
 762                excerpt_range: editable_range,
 763                cursor_offset,
 764                edits,
 765                edit_preview,
 766                snapshot,
 767                input_outline: input_outline.into(),
 768                input_events: input_events.into(),
 769                input_excerpt: input_excerpt.into(),
 770                output_excerpt,
 771                request_sent_at,
 772                response_received_at: Instant::now(),
 773            }))
 774        })
 775    }
 776
 777    fn parse_edits(
 778        output_excerpt: Arc<str>,
 779        editable_range: Range<usize>,
 780        snapshot: &BufferSnapshot,
 781    ) -> Result<Vec<(Range<Anchor>, String)>> {
 782        let content = output_excerpt.replace(CURSOR_MARKER, "");
 783
 784        let start_markers = content
 785            .match_indices(EDITABLE_REGION_START_MARKER)
 786            .collect::<Vec<_>>();
 787        anyhow::ensure!(
 788            start_markers.len() == 1,
 789            "expected exactly one start marker, found {}",
 790            start_markers.len()
 791        );
 792
 793        let end_markers = content
 794            .match_indices(EDITABLE_REGION_END_MARKER)
 795            .collect::<Vec<_>>();
 796        anyhow::ensure!(
 797            end_markers.len() == 1,
 798            "expected exactly one end marker, found {}",
 799            end_markers.len()
 800        );
 801
 802        let sof_markers = content
 803            .match_indices(START_OF_FILE_MARKER)
 804            .collect::<Vec<_>>();
 805        anyhow::ensure!(
 806            sof_markers.len() <= 1,
 807            "expected at most one start-of-file marker, found {}",
 808            sof_markers.len()
 809        );
 810
 811        let codefence_start = start_markers[0].0;
 812        let content = &content[codefence_start..];
 813
 814        let newline_ix = content.find('\n').context("could not find newline")?;
 815        let content = &content[newline_ix + 1..];
 816
 817        let codefence_end = content
 818            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 819            .context("could not find end marker")?;
 820        let new_text = &content[..codefence_end];
 821
 822        let old_text = snapshot
 823            .text_for_range(editable_range.clone())
 824            .collect::<String>();
 825
 826        Ok(Self::compute_edits(
 827            old_text,
 828            new_text,
 829            editable_range.start,
 830            &snapshot,
 831        ))
 832    }
 833
 834    pub fn compute_edits(
 835        old_text: String,
 836        new_text: &str,
 837        offset: usize,
 838        snapshot: &BufferSnapshot,
 839    ) -> Vec<(Range<Anchor>, String)> {
 840        fn tokenize(text: &str) -> Vec<&str> {
 841            let classifier = CharClassifier::new(None).for_completion(true);
 842            let mut chars = text.chars().peekable();
 843            let mut prev_ch = chars.peek().copied();
 844            let mut tokens = Vec::new();
 845            let mut start = 0;
 846            let mut end = 0;
 847            while let Some(ch) = chars.next() {
 848                let prev_kind = prev_ch.map(|ch| classifier.kind(ch));
 849                let kind = classifier.kind(ch);
 850                if Some(kind) != prev_kind || (kind == CharKind::Punctuation && Some(ch) != prev_ch)
 851                {
 852                    tokens.push(&text[start..end]);
 853                    start = end;
 854                }
 855                end += ch.len_utf8();
 856                prev_ch = Some(ch);
 857            }
 858            tokens.push(&text[start..end]);
 859            tokens
 860        }
 861
 862        let old_tokens = tokenize(&old_text);
 863        let new_tokens = tokenize(new_text);
 864
 865        let diff = similar::TextDiffConfig::default()
 866            .algorithm(similar::Algorithm::Patience)
 867            .diff_slices(&old_tokens, &new_tokens);
 868        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
 869        let mut old_start = offset;
 870        for change in diff.iter_all_changes() {
 871            let value = change.value();
 872            match change.tag() {
 873                similar::ChangeTag::Equal => {
 874                    old_start += value.len();
 875                }
 876                similar::ChangeTag::Delete => {
 877                    let old_end = old_start + value.len();
 878                    if let Some((last_old_range, _)) = edits.last_mut() {
 879                        if last_old_range.end == old_start {
 880                            last_old_range.end = old_end;
 881                        } else {
 882                            edits.push((old_start..old_end, String::new()));
 883                        }
 884                    } else {
 885                        edits.push((old_start..old_end, String::new()));
 886                    }
 887                    old_start = old_end;
 888                }
 889                similar::ChangeTag::Insert => {
 890                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
 891                        if last_old_range.end == old_start {
 892                            last_new_text.push_str(value);
 893                        } else {
 894                            edits.push((old_start..old_start, value.into()));
 895                        }
 896                    } else {
 897                        edits.push((old_start..old_start, value.into()));
 898                    }
 899                }
 900            }
 901        }
 902
 903        edits
 904            .into_iter()
 905            .map(|(mut old_range, new_text)| {
 906                let prefix_len = common_prefix(
 907                    snapshot.chars_for_range(old_range.clone()),
 908                    new_text.chars(),
 909                );
 910                old_range.start += prefix_len;
 911                let suffix_len = common_prefix(
 912                    snapshot.reversed_chars_for_range(old_range.clone()),
 913                    new_text[prefix_len..].chars().rev(),
 914                );
 915                old_range.end = old_range.end.saturating_sub(suffix_len);
 916
 917                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 918                let range = if old_range.is_empty() {
 919                    let anchor = snapshot.anchor_after(old_range.start);
 920                    anchor..anchor
 921                } else {
 922                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 923                };
 924                (range, new_text)
 925            })
 926            .collect()
 927    }
 928
 929    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
 930        self.rated_completions.contains(&completion_id)
 931    }
 932
 933    pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
 934        self.shown_completions.push_front(completion.clone());
 935        if self.shown_completions.len() > 50 {
 936            let completion = self.shown_completions.pop_back().unwrap();
 937            self.rated_completions.remove(&completion.id);
 938        }
 939        cx.notify();
 940    }
 941
    /// Records the user's rating of `completion` and reports it through telemetry.
    ///
    /// The completion's id is marked as rated. Next, an "Edit Prediction Rated" event is
    /// emitted, carrying the `rating`, the completion's prompt inputs (events, excerpt,
    /// outline), its output excerpt, and the free-form `feedback`. Pending telemetry is then
    /// flushed immediately, and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Flush right away rather than waiting for the next batch.
        self.client.telemetry().flush_events();
        cx.notify();
    }
 962
 963    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
 964        self.shown_completions.iter()
 965    }
 966
 967    pub fn shown_completions_len(&self) -> usize {
 968        self.shown_completions.len()
 969    }
 970
 971    fn report_changes_for_buffer(
 972        &mut self,
 973        buffer: &Entity<Buffer>,
 974        cx: &mut Context<Self>,
 975    ) -> BufferSnapshot {
 976        self.register_buffer(buffer, cx);
 977
 978        let registered_buffer = self
 979            .registered_buffers
 980            .get_mut(&buffer.entity_id())
 981            .unwrap();
 982        let new_snapshot = buffer.read(cx).snapshot();
 983
 984        if new_snapshot.version != registered_buffer.snapshot.version {
 985            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 986            self.push_event(Event::BufferChange {
 987                old_snapshot,
 988                new_snapshot: new_snapshot.clone(),
 989                timestamp: Instant::now(),
 990            });
 991        }
 992
 993        new_snapshot
 994    }
 995
 996    fn load_data_collection_choices() -> DataCollectionChoice {
 997        let choice = KEY_VALUE_STORE
 998            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
 999            .log_err()
1000            .flatten();
1001
1002        match choice.as_deref() {
1003            Some("true") => DataCollectionChoice::Enabled,
1004            Some("false") => DataCollectionChoice::Disabled,
1005            Some(_) => {
1006                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1007                DataCollectionChoice::NotAnswered
1008            }
1009            None => DataCollectionChoice::NotAnswered,
1010        }
1011    }
1012}
1013
/// Tracks whether a worktree appears to be open source, judged by the first readable
/// file among `LICENSE_FILES_TO_CHECK`.
struct LicenseDetectionWatcher {
    /// Latest detection result; starts out `false`.
    is_open_source_rx: watch::Receiver<bool>,
    /// Background detection task, held by the watcher (presumably so that dropping the
    /// watcher also stops detection — confirm against gpui `Task` drop semantics).
    _is_open_source_task: Task<()>,
}
1018
1019impl LicenseDetectionWatcher {
1020    pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
1021        let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
1022
1023        // Check if worktree is a single file, if so we do not need to check for a LICENSE file
1024        let task = if worktree.abs_path().is_file() {
1025            Task::ready(())
1026        } else {
1027            let loaded_files = LICENSE_FILES_TO_CHECK
1028                .iter()
1029                .map(Path::new)
1030                .map(|file| worktree.load_file(file, cx))
1031                .collect::<ArrayVec<_, { LICENSE_FILES_TO_CHECK.len() }>>();
1032
1033            cx.background_executor().spawn(async move {
1034                for loaded_file in loaded_files.into_iter() {
1035                    let Ok(loaded_file) = loaded_file.await else {
1036                        continue;
1037                    };
1038
1039                    let path = &loaded_file.file.path;
1040                    if is_license_eligible_for_data_collection(&loaded_file.text) {
1041                        log::info!("detected '{path:?}' as open source license");
1042                        *is_open_source_tx.borrow_mut() = true;
1043                    } else {
1044                        log::info!("didn't detect '{path:?}' as open source license");
1045                    }
1046
1047                    // stop on the first license that successfully read
1048                    return;
1049                }
1050
1051                log::debug!("didn't find a license file to check, assuming closed source");
1052            })
1053        };
1054
1055        Self {
1056            is_open_source_rx,
1057            _is_open_source_task: task,
1058        }
1059    }
1060
1061    /// Answers false until we find out it's open source
1062    pub fn is_project_open_source(&self) -> bool {
1063        *self.is_open_source_rx.borrow()
1064    }
1065}
1066
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two char streams.
///
/// Characters are compared pairwise until the first mismatch or until either stream ends.
/// Pass reversed iterators to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1073
1074fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
1075    let mut input_outline = String::new();
1076
1077    writeln!(
1078        input_outline,
1079        "```{}",
1080        snapshot
1081            .file()
1082            .map_or(Cow::Borrowed("untitled"), |file| file
1083                .path()
1084                .to_string_lossy())
1085    )
1086    .unwrap();
1087
1088    if let Some(outline) = snapshot.outline(None) {
1089        for item in &outline.items {
1090            let spacing = " ".repeat(item.depth);
1091            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1092        }
1093    }
1094
1095    writeln!(input_outline, "```").unwrap();
1096
1097    input_outline
1098}
1099
1100fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1101    let mut result = String::new();
1102    for event in events.iter().rev() {
1103        let event_string = event.to_prompt();
1104        let event_tokens = tokens_for_bytes(event_string.len());
1105        if event_tokens > remaining_tokens {
1106            break;
1107        }
1108
1109        if !result.is_empty() {
1110            result.insert_str(0, "\n\n");
1111        }
1112        result.insert_str(0, &event_string);
1113        remaining_tokens -= event_tokens;
1114    }
1115    result
1116}
1117
/// Per-buffer state that `Zeta` keeps for each registered buffer.
struct RegisteredBuffer {
    /// Snapshot last recorded for the buffer. Its version is compared with the buffer's
    /// current snapshot to detect changes.
    snapshot: BufferSnapshot,
    /// Subscriptions tied to this registration (created elsewhere — NOTE(review): confirm
    /// what they observe).
    _subscriptions: [gpui::Subscription; 2],
}
1122
/// A recorded user action; rendered into prediction prompts via `to_prompt`.
#[derive(Clone)]
enum Event {
    /// A buffer changed between two snapshots (its text, its path, or both).
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded.
        timestamp: Instant,
    },
}
1131
1132impl Event {
1133    fn to_prompt(&self) -> String {
1134        match self {
1135            Event::BufferChange {
1136                old_snapshot,
1137                new_snapshot,
1138                ..
1139            } => {
1140                let mut prompt = String::new();
1141
1142                let old_path = old_snapshot
1143                    .file()
1144                    .map(|f| f.path().as_ref())
1145                    .unwrap_or(Path::new("untitled"));
1146                let new_path = new_snapshot
1147                    .file()
1148                    .map(|f| f.path().as_ref())
1149                    .unwrap_or(Path::new("untitled"));
1150                if old_path != new_path {
1151                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1152                }
1153
1154                let diff =
1155                    similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
1156                        .unified_diff()
1157                        .to_string();
1158                if !diff.is_empty() {
1159                    write!(
1160                        prompt,
1161                        "User edited {:?}:\n```diff\n{}\n```",
1162                        new_path, diff
1163                    )
1164                    .unwrap();
1165                }
1166
1167                prompt
1168            }
1169        }
1170    }
1171}
1172
/// The completion the provider is currently offering, tagged with the buffer it was
/// requested for.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion was requested for.
    buffer_id: EntityId,
    completion: InlineCompletion,
}
1178
1179impl CurrentInlineCompletion {
1180    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1181        if self.buffer_id != old_completion.buffer_id {
1182            return true;
1183        }
1184
1185        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1186            return true;
1187        };
1188        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1189            return false;
1190        };
1191
1192        if old_edits.len() == 1 && new_edits.len() == 1 {
1193            let (old_range, old_text) = &old_edits[0];
1194            let (new_range, new_text) = &new_edits[0];
1195            new_range == old_range && new_text.starts_with(old_text)
1196        } else {
1197            true
1198        }
1199    }
1200}
1201
/// A completion request that is still in flight.
struct PendingCompletion {
    /// Lets the request's task check whether it is still the oldest pending request.
    id: usize,
    /// Task performing the request; held by the provider while the request is pending.
    _task: Task<()>,
}
1206
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been stored yet (or the stored value was unrecognized).
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
1213
1214impl DataCollectionChoice {
1215    pub fn is_enabled(self) -> bool {
1216        match self {
1217            Self::Enabled => true,
1218            Self::NotAnswered | Self::Disabled => false,
1219        }
1220    }
1221
1222    pub fn is_answered(self) -> bool {
1223        match self {
1224            Self::Enabled | Self::Disabled => true,
1225            Self::NotAnswered => false,
1226        }
1227    }
1228
1229    pub fn toggle(&self) -> DataCollectionChoice {
1230        match self {
1231            Self::Enabled => Self::Disabled,
1232            Self::Disabled => Self::Enabled,
1233            Self::NotAnswered => Self::Enabled,
1234        }
1235    }
1236}
1237
1238impl From<bool> for DataCollectionChoice {
1239    fn from(value: bool) -> Self {
1240        match value {
1241            true => DataCollectionChoice::Enabled,
1242            false => DataCollectionChoice::Disabled,
1243        }
1244    }
1245}
1246
/// Data-collection state for the buffer an edit-prediction provider was created for.
pub struct ProviderDataCollection {
    /// When set to None, data collection is not possible in the provider buffer
    choice: Option<Entity<DataCollectionChoice>>,
    /// License detection for the buffer's worktree. `ProviderDataCollection::new` sets
    /// this together with `choice`, so both are `Some` or both are `None`.
    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
1252
1253impl ProviderDataCollection {
1254    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1255        let choice_and_watcher = buffer.and_then(|buffer| {
1256            let file = buffer.read(cx).file()?;
1257
1258            if !file.is_local() || file.is_private() {
1259                return None;
1260            }
1261
1262            let zeta = zeta.read(cx);
1263            let choice = zeta.data_collection_choice.clone();
1264
1265            let license_detection_watcher = zeta
1266                .license_detection_watchers
1267                .get(&file.worktree_id(cx))
1268                .cloned()?;
1269
1270            Some((choice, license_detection_watcher))
1271        });
1272
1273        if let Some((choice, watcher)) = choice_and_watcher {
1274            ProviderDataCollection {
1275                choice: Some(choice),
1276                license_detection_watcher: Some(watcher),
1277            }
1278        } else {
1279            ProviderDataCollection {
1280                choice: None,
1281                license_detection_watcher: None,
1282            }
1283        }
1284    }
1285
1286    pub fn can_collect_data(&self, cx: &App) -> bool {
1287        self.is_data_collection_enabled(cx) && self.is_project_open_source()
1288    }
1289
1290    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1291        self.choice
1292            .as_ref()
1293            .is_some_and(|choice| choice.read(cx).is_enabled())
1294    }
1295
1296    fn is_project_open_source(&self) -> bool {
1297        self.license_detection_watcher
1298            .as_ref()
1299            .is_some_and(|watcher| watcher.is_project_open_source())
1300    }
1301
1302    pub fn toggle(&mut self, cx: &mut App) {
1303        if let Some(choice) = self.choice.as_mut() {
1304            let new_choice = choice.update(cx, |choice, _cx| {
1305                let new_choice = choice.toggle();
1306                *choice = new_choice;
1307                new_choice
1308            });
1309
1310            db::write_and_log(cx, move || {
1311                KEY_VALUE_STORE.write_kvp(
1312                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1313                    new_choice.is_enabled().to_string(),
1314                )
1315            });
1316        }
1317    }
1318}
1319
/// Edit-prediction provider backed by a shared `Zeta` entity.
pub struct ZetaInlineCompletionProvider {
    zeta: Entity<Zeta>,
    /// In-flight requests, oldest first; at most two are kept.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request that is started.
    next_pending_completion_id: usize,
    /// Completion currently being offered, if any.
    current_completion: Option<CurrentInlineCompletion>,
    /// Data-collection state for this provider's buffer. This is not an `Option`: its inner
    /// fields are `None` when data collection is impossible for the buffer.
    provider_data_collection: ProviderDataCollection,
    /// When the last completion request started; used to throttle new requests.
    last_request_timestamp: Instant,
}
1329
1330impl ZetaInlineCompletionProvider {
1331    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1332
1333    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1334        Self {
1335            zeta,
1336            pending_completions: ArrayVec::new(),
1337            next_pending_completion_id: 0,
1338            current_completion: None,
1339            provider_data_collection,
1340            last_request_timestamp: Instant::now(),
1341        }
1342    }
1343}
1344
/// Implements the `EditPredictionProvider` trait on top of a shared `Zeta` entity.
impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider {
    fn name() -> &'static str {
        "zed-predict"
    }

    fn display_name() -> &'static str {
        "Zed's Edit Predictions"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn show_tab_accept_marker() -> bool {
        true
    }

    /// Reports whether data collection is enabled, together with whether the project was
    /// detected as open source.
    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
        let is_project_open_source = self.provider_data_collection.is_project_open_source();

        if self.provider_data_collection.is_data_collection_enabled(cx) {
            DataCollectionState::Enabled {
                is_project_open_source,
            }
        } else {
            DataCollectionState::Disabled {
                is_project_open_source,
            }
        }
    }

    fn toggle_data_collection(&mut self, cx: &mut App) {
        self.provider_data_collection.toggle(cx);
    }

    /// Always enabled, regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }

    /// Terms must be accepted until `Zeta` records that they were.
    fn needs_terms_acceptance(&self, cx: &App) -> bool {
        !self.zeta.read(cx).tos_accepted
    }

    /// True while at least one completion request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_completions.is_empty()
    }

    /// Starts a completion request for `position` in `buffer`.
    ///
    /// Nothing happens in either of these cases:
    /// - the terms have not been accepted;
    /// - the current completion still applies to the buffer.
    ///
    /// Requests are throttled so that each starts at least `THROTTLE_TIMEOUT` after the
    /// previous one started; `_debounce` is ignored. At most two requests are pending at
    /// once; see the comment at the end of this method.
    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        if !self.zeta.read(cx).tos_accepted {
            return;
        }

        // Keep offering the current completion while it can still be mapped onto the buffer.
        if let Some(current_completion) = self.current_completion.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_completion
                .completion
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_completion_id = self.next_pending_completion_id;
        self.next_pending_completion_id += 1;
        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(|this, mut cx| async move {
            // Throttle: wait until THROTTLE_TIMEOUT has elapsed since the previous request started.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let completion_request = this.update(&mut cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_completion(
                        project.as_ref(),
                        &buffer,
                        position,
                        can_collect_data,
                        cx,
                    )
                })
            });

            let completion = match completion_request {
                Ok(completion_request) => {
                    let completion_request = completion_request.await;
                    completion_request.map(|c| {
                        c.map(|completion| CurrentInlineCompletion {
                            buffer_id: buffer.entity_id(),
                            completion,
                        })
                    })
                }
                Err(error) => Err(error),
            };
            // Errors are logged; a failed request and an empty response both end here.
            let Some(new_completion) = completion
                .context("edit prediction failed")
                .log_err()
                .flatten()
            else {
                this.update(&mut cx, |this, cx| {
                    // If this is the oldest pending request, retire just it; otherwise an
                    // older request is still outstanding and all pending requests are dropped.
                    // NOTE(review): `[0]` assumes the list is non-empty here; presumably
                    // guaranteed because clearing the list drops (cancels) these tasks.
                    if this.pending_completions[0].id == pending_completion_id {
                        this.pending_completions.remove(0);
                    } else {
                        this.pending_completions.clear();
                    }

                    cx.notify();
                })
                .ok();
                return;
            };

            this.update(&mut cx, |this, cx| {
                // Same pending-list bookkeeping as the failure path above.
                if this.pending_completions[0].id == pending_completion_id {
                    this.pending_completions.remove(0);
                } else {
                    this.pending_completions.clear();
                }

                // Only swap in the new completion when it should replace the one being
                // shown; a replaced or first completion is recorded as shown.
                if let Some(old_completion) = this.current_completion.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_completion.should_replace_completion(&old_completion, &snapshot) {
                        this.zeta.update(cx, |zeta, cx| {
                            zeta.completion_shown(&new_completion.completion, cx);
                        });
                        this.current_completion = Some(new_completion);
                    }
                } else {
                    this.zeta.update(cx, |zeta, cx| {
                        zeta.completion_shown(&new_completion.completion, cx);
                    });
                    this.current_completion = Some(new_completion);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending completions. When we already
        // have two, we replace the newest one.
        if self.pending_completions.len() <= 1 {
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        } else if self.pending_completions.len() == 2 {
            self.pending_completions.pop();
            self.pending_completions.push(PendingCompletion {
                id: pending_completion_id,
                _task: task,
            });
        }
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: language::Anchor,
        _direction: inline_completion::Direction,
        _cx: &mut Context<Self>,
    ) {
        // Right now we don't support cycling.
    }

    /// Drops all pending requests; the current completion is left in place.
    fn accept(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
    }

    /// Drops all pending requests and the current completion.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_completions.clear();
        self.current_completion.take();
    }

    /// Returns the portion of the current completion nearest the cursor.
    ///
    /// The result is the edit whose start or end row is closest to the cursor row, plus
    /// the contiguous neighboring edits that lie within one row of that closest edit.
    /// The current completion is cleared (and `None` returned) in either of these cases:
    /// - it was produced for a different buffer;
    /// - it no longer applies to the buffer.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<inline_completion::InlineCompletion> {
        let CurrentInlineCompletion {
            buffer_id,
            completion,
            ..
        } = self.current_completion.as_mut()?;

        // Invalidate previous completion if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_completion.take();
            return None;
        }

        let buffer = buffer.read(cx);
        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
            self.current_completion.take();
            return None;
        };

        // `?` yields None when there are no edits at all.
        let cursor_row = cursor_position.to_point(buffer).row;
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Grow the group backwards. Distances are measured from the closest edit itself,
        // not from the previously grouped edit. NOTE(review): the row subtraction assumes
        // `edits` are ordered by position so it cannot underflow.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Grow the group forwards, again measuring from the closest edit.
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(inline_completion::InlineCompletion {
            id: Some(completion.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(completion.edit_preview.clone()),
        })
    }
}
1598
/// Estimates how many model tokens `bytes` bytes of prompt text consume, rounding down.
fn tokens_for_bytes(bytes: usize) -> usize {
    /// Assumed bytes per token when budgeting model input. The value is deliberately low,
    /// which overestimates token counts and so errs on the side of underestimating how
    /// much input fits.
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    bytes / BYTES_PER_TOKEN_GUESS
}
1605
1606#[cfg(test)]
1607mod tests {
1608    use client::test::FakeServer;
1609    use clock::FakeSystemClock;
1610    use gpui::TestAppContext;
1611    use http_client::FakeHttpClient;
1612    use indoc::indoc;
1613    use language::Point;
1614    use language_models::RefreshLlmTokenListener;
1615    use rpc::proto;
1616    use settings::SettingsStore;
1617
1618    use super::*;
1619
1620    #[gpui::test]
1621    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
1622        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1623        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1624            to_completion_edits(
1625                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1626                &buffer,
1627                cx,
1628            )
1629            .into()
1630        });
1631
1632        let edit_preview = cx
1633            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1634            .await;
1635
1636        let completion = InlineCompletion {
1637            edits,
1638            edit_preview,
1639            path: Path::new("").into(),
1640            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1641            id: InlineCompletionId(Uuid::new_v4()),
1642            excerpt_range: 0..0,
1643            cursor_offset: 0,
1644            input_outline: "".into(),
1645            input_events: "".into(),
1646            input_excerpt: "".into(),
1647            output_excerpt: "".into(),
1648            request_sent_at: Instant::now(),
1649            response_received_at: Instant::now(),
1650        };
1651
1652        cx.update(|cx| {
1653            assert_eq!(
1654                from_completion_edits(
1655                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1656                    &buffer,
1657                    cx
1658                ),
1659                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1660            );
1661
1662            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1663            assert_eq!(
1664                from_completion_edits(
1665                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1666                    &buffer,
1667                    cx
1668                ),
1669                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1670            );
1671
1672            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1673            assert_eq!(
1674                from_completion_edits(
1675                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1676                    &buffer,
1677                    cx
1678                ),
1679                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1680            );
1681
1682            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1683            assert_eq!(
1684                from_completion_edits(
1685                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1686                    &buffer,
1687                    cx
1688                ),
1689                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1690            );
1691
1692            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1693            assert_eq!(
1694                from_completion_edits(
1695                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1696                    &buffer,
1697                    cx
1698                ),
1699                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1700            );
1701
1702            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1703            assert_eq!(
1704                from_completion_edits(
1705                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1706                    &buffer,
1707                    cx
1708                ),
1709                vec![(9..11, "".to_string())]
1710            );
1711
1712            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1713            assert_eq!(
1714                from_completion_edits(
1715                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1716                    &buffer,
1717                    cx
1718                ),
1719                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1720            );
1721
1722            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1723            assert_eq!(
1724                from_completion_edits(
1725                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1726                    &buffer,
1727                    cx
1728                ),
1729                vec![(4..4, "M".to_string())]
1730            );
1731
1732            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1733            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1734        })
1735    }
1736
1737    #[gpui::test]
1738    async fn test_clean_up_diff(cx: &mut TestAppContext) {
1739        cx.update(|cx| {
1740            let settings_store = SettingsStore::test(cx);
1741            cx.set_global(settings_store);
1742            client::init_settings(cx);
1743        });
1744
1745        let edits = edits_for_prediction(
1746            indoc! {"
1747                fn main() {
1748                    let word_1 = \"lorem\";
1749                    let range = word.len()..word.len();
1750                }
1751            "},
1752            indoc! {"
1753                <|editable_region_start|>
1754                fn main() {
1755                    let word_1 = \"lorem\";
1756                    let range = word_1.len()..word_1.len();
1757                }
1758
1759                <|editable_region_end|>
1760            "},
1761            cx,
1762        )
1763        .await;
1764        assert_eq!(
1765            edits,
1766            [
1767                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
1768                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
1769            ]
1770        );
1771
1772        let edits = edits_for_prediction(
1773            indoc! {"
1774                fn main() {
1775                    let story = \"the quick\"
1776                }
1777            "},
1778            indoc! {"
1779                <|editable_region_start|>
1780                fn main() {
1781                    let story = \"the quick brown fox jumps over the lazy dog\";
1782                }
1783
1784                <|editable_region_end|>
1785            "},
1786            cx,
1787        )
1788        .await;
1789        assert_eq!(
1790            edits,
1791            [
1792                (
1793                    Point::new(1, 26)..Point::new(1, 26),
1794                    " brown fox jumps over the lazy dog".to_string()
1795                ),
1796                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
1797            ]
1798        );
1799    }
1800
1801    #[gpui::test]
1802    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
1803        cx.update(|cx| {
1804            let settings_store = SettingsStore::test(cx);
1805            cx.set_global(settings_store);
1806            client::init_settings(cx);
1807        });
1808
1809        let buffer_content = "lorem\n";
1810        let completion_response = indoc! {"
1811            ```animals.js
1812            <|start_of_file|>
1813            <|editable_region_start|>
1814            lorem
1815            ipsum
1816            <|editable_region_end|>
1817            ```"};
1818
1819        let http_client = FakeHttpClient::create(move |_| async move {
1820            Ok(http_client::Response::builder()
1821                .status(200)
1822                .body(
1823                    serde_json::to_string(&PredictEditsResponse {
1824                        request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
1825                            .unwrap(),
1826                        output_excerpt: completion_response.to_string(),
1827                    })
1828                    .unwrap()
1829                    .into(),
1830                )
1831                .unwrap())
1832        });
1833
1834        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1835        cx.update(|cx| {
1836            RefreshLlmTokenListener::register(client.clone(), cx);
1837        });
1838        let server = FakeServer::for_client(42, &client, cx).await;
1839        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1840        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
1841
1842        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1843        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1844        let completion_task = zeta.update(cx, |zeta, cx| {
1845            zeta.request_completion(None, &buffer, cursor, false, cx)
1846        });
1847
1848        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
1849        server.respond(
1850            token_request.receipt(),
1851            proto::GetLlmTokenResponse { token: "".into() },
1852        );
1853
1854        let completion = completion_task.await.unwrap().unwrap();
1855        buffer.update(cx, |buffer, cx| {
1856            buffer.edit(completion.edits.iter().cloned(), None, cx)
1857        });
1858        assert_eq!(
1859            buffer.read_with(cx, |buffer, _| buffer.text()),
1860            "lorem\nipsum"
1861        );
1862    }
1863
1864    async fn edits_for_prediction(
1865        buffer_content: &str,
1866        completion_response: &str,
1867        cx: &mut TestAppContext,
1868    ) -> Vec<(Range<Point>, String)> {
1869        let completion_response = completion_response.to_string();
1870        let http_client = FakeHttpClient::create(move |_| {
1871            let completion = completion_response.clone();
1872            async move {
1873                Ok(http_client::Response::builder()
1874                    .status(200)
1875                    .body(
1876                        serde_json::to_string(&PredictEditsResponse {
1877                            request_id: Uuid::new_v4(),
1878                            output_excerpt: completion,
1879                        })
1880                        .unwrap()
1881                        .into(),
1882                    )
1883                    .unwrap())
1884            }
1885        });
1886
1887        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1888        cx.update(|cx| {
1889            RefreshLlmTokenListener::register(client.clone(), cx);
1890        });
1891        let server = FakeServer::for_client(42, &client, cx).await;
1892        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1893        let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
1894
1895        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1896        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
1897        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1898        let completion_task = zeta.update(cx, |zeta, cx| {
1899            zeta.request_completion(None, &buffer, cursor, false, cx)
1900        });
1901
1902        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
1903        server.respond(
1904            token_request.receipt(),
1905            proto::GetLlmTokenResponse { token: "".into() },
1906        );
1907
1908        let completion = completion_task.await.unwrap().unwrap();
1909        completion
1910            .edits
1911            .into_iter()
1912            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
1913            .collect::<Vec<_>>()
1914    }
1915
1916    fn to_completion_edits(
1917        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1918        buffer: &Entity<Buffer>,
1919        cx: &App,
1920    ) -> Vec<(Range<Anchor>, String)> {
1921        let buffer = buffer.read(cx);
1922        iterator
1923            .into_iter()
1924            .map(|(range, text)| {
1925                (
1926                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1927                    text,
1928                )
1929            })
1930            .collect()
1931    }
1932
1933    fn from_completion_edits(
1934        editor_edits: &[(Range<Anchor>, String)],
1935        buffer: &Entity<Buffer>,
1936        cx: &App,
1937    ) -> Vec<(Range<usize>, String)> {
1938        let buffer = buffer.read(cx);
1939        editor_edits
1940            .iter()
1941            .map(|(range, text)| {
1942                (
1943                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
1944                    text.clone(),
1945                )
1946            })
1947            .collect()
1948    }
1949
1950    #[ctor::ctor]
1951    fn init_logger() {
1952        if std::env::var("RUST_LOG").is_ok() {
1953            env_logger::init();
1954        }
1955    }
1956}