zeta.rs

   1mod completion_diff_element;
   2mod init;
   3mod input_excerpt;
   4mod license_detection;
   5mod onboarding_modal;
   6mod onboarding_telemetry;
   7mod rate_completion_modal;
   8
   9pub(crate) use completion_diff_element::*;
  10use db::kvp::KEY_VALUE_STORE;
  11pub use init::*;
  12use inline_completion::DataCollectionState;
  13use license_detection::LICENSE_FILES_TO_CHECK;
  14pub use license_detection::is_license_eligible_for_data_collection;
  15pub use rate_completion_modal::*;
  16
  17use anyhow::{Context as _, Result, anyhow};
  18use arrayvec::ArrayVec;
  19use client::{Client, UserStore};
  20use collections::{HashMap, HashSet, VecDeque};
  21use futures::AsyncReadExt;
  22use gpui::{
  23    App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
  24    Subscription, Task, WeakEntity, actions,
  25};
  26use http_client::{HttpClient, Method};
  27use input_excerpt::excerpt_for_cursor_position;
  28use language::{
  29    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
  30};
  31use language_model::{LlmApiToken, RefreshLlmTokenListener};
  32use postage::watch;
  33use project::Project;
  34use release_channel::AppVersion;
  35use settings::WorktreeId;
  36use std::str::FromStr;
  37use std::{
  38    borrow::Cow,
  39    cmp,
  40    fmt::Write,
  41    future::Future,
  42    mem,
  43    ops::Range,
  44    path::Path,
  45    rc::Rc,
  46    sync::Arc,
  47    time::{Duration, Instant},
  48};
  49use telemetry_events::InlineCompletionRating;
  50use thiserror::Error;
  51use util::ResultExt;
  52use uuid::Uuid;
  53use workspace::Workspace;
  54use workspace::notifications::{ErrorMessagePrompt, NotificationId};
  55use worktree::Worktree;
  56use zed_llm_client::{
  57    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
  58    PredictEditsResponse,
  59};
  60
  61const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  62const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  63const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  64const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  65const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  66const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  67
  68const MAX_CONTEXT_TOKENS: usize = 150;
  69const MAX_REWRITE_TOKENS: usize = 350;
  70const MAX_EVENT_TOKENS: usize = 500;
  71
  72/// Maximum number of events to track.
  73const MAX_EVENT_COUNT: usize = 16;
  74
  75actions!(edit_prediction, [ClearHistory]);
  76
  77#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  78pub struct InlineCompletionId(Uuid);
  79
  80impl From<InlineCompletionId> for gpui::ElementId {
  81    fn from(value: InlineCompletionId) -> Self {
  82        gpui::ElementId::Uuid(value.0)
  83    }
  84}
  85
  86impl std::fmt::Display for InlineCompletionId {
  87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  88        write!(f, "{}", self.0)
  89    }
  90}
  91
  92#[derive(Clone)]
  93struct ZetaGlobal(Entity<Zeta>);
  94
  95impl Global for ZetaGlobal {}
  96
  97#[derive(Clone)]
  98pub struct InlineCompletion {
  99    id: InlineCompletionId,
 100    path: Arc<Path>,
 101    excerpt_range: Range<usize>,
 102    cursor_offset: usize,
 103    edits: Arc<[(Range<Anchor>, String)]>,
 104    snapshot: BufferSnapshot,
 105    edit_preview: EditPreview,
 106    input_outline: Arc<str>,
 107    input_events: Arc<str>,
 108    input_excerpt: Arc<str>,
 109    output_excerpt: Arc<str>,
 110    request_sent_at: Instant,
 111    response_received_at: Instant,
 112}
 113
 114impl InlineCompletion {
 115    fn latency(&self) -> Duration {
 116        self.response_received_at
 117            .duration_since(self.request_sent_at)
 118    }
 119
 120    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 121        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 122    }
 123}
 124
 125fn interpolate(
 126    old_snapshot: &BufferSnapshot,
 127    new_snapshot: &BufferSnapshot,
 128    current_edits: Arc<[(Range<Anchor>, String)]>,
 129) -> Option<Vec<(Range<Anchor>, String)>> {
 130    let mut edits = Vec::new();
 131
 132    let mut model_edits = current_edits.into_iter().peekable();
 133    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 134        while let Some((model_old_range, _)) = model_edits.peek() {
 135            let model_old_range = model_old_range.to_offset(old_snapshot);
 136            if model_old_range.end < user_edit.old.start {
 137                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 138                edits.push((model_old_range.clone(), model_new_text.clone()));
 139            } else {
 140                break;
 141            }
 142        }
 143
 144        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 145            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 146            if user_edit.old == model_old_offset_range {
 147                let user_new_text = new_snapshot
 148                    .text_for_range(user_edit.new.clone())
 149                    .collect::<String>();
 150
 151                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 152                    if !model_suffix.is_empty() {
 153                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 154                        edits.push((anchor..anchor, model_suffix.to_string()));
 155                    }
 156
 157                    model_edits.next();
 158                    continue;
 159                }
 160            }
 161        }
 162
 163        return None;
 164    }
 165
 166    edits.extend(model_edits.cloned());
 167
 168    if edits.is_empty() { None } else { Some(edits) }
 169}
 170
 171impl std::fmt::Debug for InlineCompletion {
 172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 173        f.debug_struct("InlineCompletion")
 174            .field("id", &self.id)
 175            .field("path", &self.path)
 176            .field("edits", &self.edits)
 177            .finish_non_exhaustive()
 178    }
 179}
 180
 181pub struct Zeta {
 182    workspace: Option<WeakEntity<Workspace>>,
 183    client: Arc<Client>,
 184    events: VecDeque<Event>,
 185    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 186    shown_completions: VecDeque<InlineCompletion>,
 187    rated_completions: HashSet<InlineCompletionId>,
 188    data_collection_choice: Entity<DataCollectionChoice>,
 189    llm_token: LlmApiToken,
 190    _llm_token_subscription: Subscription,
 191    /// Whether the terms of service have been accepted.
 192    tos_accepted: bool,
 193    /// Whether an update to a newer version of Zed is required to continue using Zeta.
 194    update_required: bool,
 195    _user_store_subscription: Subscription,
 196    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 197}
 198
 199impl Zeta {
 200    pub fn global(cx: &mut App) -> Option<Entity<Self>> {
 201        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 202    }
 203
 204    pub fn register(
 205        workspace: Option<WeakEntity<Workspace>>,
 206        worktree: Option<Entity<Worktree>>,
 207        client: Arc<Client>,
 208        user_store: Entity<UserStore>,
 209        cx: &mut App,
 210    ) -> Entity<Self> {
 211        let this = Self::global(cx).unwrap_or_else(|| {
 212            let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
 213            cx.set_global(ZetaGlobal(entity.clone()));
 214            entity
 215        });
 216
 217        this.update(cx, move |this, cx| {
 218            if let Some(worktree) = worktree {
 219                worktree.update(cx, |worktree, cx| {
 220                    this.license_detection_watchers
 221                        .entry(worktree.id())
 222                        .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(worktree, cx)));
 223                });
 224            }
 225        });
 226
 227        this
 228    }
 229
 230    pub fn clear_history(&mut self) {
 231        self.events.clear();
 232    }
 233
 234    fn new(
 235        workspace: Option<WeakEntity<Workspace>>,
 236        client: Arc<Client>,
 237        user_store: Entity<UserStore>,
 238        cx: &mut Context<Self>,
 239    ) -> Self {
 240        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 241
 242        let data_collection_choice = Self::load_data_collection_choices();
 243        let data_collection_choice = cx.new(|_| data_collection_choice);
 244
 245        Self {
 246            workspace,
 247            client,
 248            events: VecDeque::new(),
 249            shown_completions: VecDeque::new(),
 250            rated_completions: HashSet::default(),
 251            registered_buffers: HashMap::default(),
 252            data_collection_choice,
 253            llm_token: LlmApiToken::default(),
 254            _llm_token_subscription: cx.subscribe(
 255                &refresh_llm_token_listener,
 256                |this, _listener, _event, cx| {
 257                    let client = this.client.clone();
 258                    let llm_token = this.llm_token.clone();
 259                    cx.spawn(async move |_this, _cx| {
 260                        llm_token.refresh(&client).await?;
 261                        anyhow::Ok(())
 262                    })
 263                    .detach_and_log_err(cx);
 264                },
 265            ),
 266            tos_accepted: user_store
 267                .read(cx)
 268                .current_user_has_accepted_terms()
 269                .unwrap_or(false),
 270            update_required: false,
 271            _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
 272                match event {
 273                    client::user::Event::PrivateUserInfoUpdated => {
 274                        this.tos_accepted = user_store
 275                            .read(cx)
 276                            .current_user_has_accepted_terms()
 277                            .unwrap_or(false);
 278                    }
 279                    _ => {}
 280                }
 281            }),
 282            license_detection_watchers: HashMap::default(),
 283        }
 284    }
 285
 286    fn push_event(&mut self, event: Event) {
 287        if let Some(Event::BufferChange {
 288            new_snapshot: last_new_snapshot,
 289            timestamp: last_timestamp,
 290            ..
 291        }) = self.events.back_mut()
 292        {
 293            // Coalesce edits for the same buffer when they happen one after the other.
 294            let Event::BufferChange {
 295                old_snapshot,
 296                new_snapshot,
 297                timestamp,
 298            } = &event;
 299
 300            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 301                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 302                && old_snapshot.version == last_new_snapshot.version
 303            {
 304                *last_new_snapshot = new_snapshot.clone();
 305                *last_timestamp = *timestamp;
 306                return;
 307            }
 308        }
 309
 310        self.events.push_back(event);
 311        if self.events.len() >= MAX_EVENT_COUNT {
 312            self.events.drain(..MAX_EVENT_COUNT / 2);
 313        }
 314    }
 315
 316    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
 317        let buffer_id = buffer.entity_id();
 318        let weak_buffer = buffer.downgrade();
 319
 320        if let std::collections::hash_map::Entry::Vacant(entry) =
 321            self.registered_buffers.entry(buffer_id)
 322        {
 323            let snapshot = buffer.read(cx).snapshot();
 324
 325            entry.insert(RegisteredBuffer {
 326                snapshot,
 327                _subscriptions: [
 328                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 329                        this.handle_buffer_event(buffer, event, cx);
 330                    }),
 331                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 332                        this.registered_buffers.remove(&weak_buffer.entity_id());
 333                    }),
 334                ],
 335            });
 336        };
 337    }
 338
 339    fn handle_buffer_event(
 340        &mut self,
 341        buffer: Entity<Buffer>,
 342        event: &language::BufferEvent,
 343        cx: &mut Context<Self>,
 344    ) {
 345        if let language::BufferEvent::Edited = event {
 346            self.report_changes_for_buffer(&buffer, cx);
 347        }
 348    }
 349
 350    fn request_completion_impl<F, R>(
 351        &mut self,
 352        workspace: Option<Entity<Workspace>>,
 353        project: Option<&Entity<Project>>,
 354        buffer: &Entity<Buffer>,
 355        cursor: language::Anchor,
 356        can_collect_data: bool,
 357        cx: &mut Context<Self>,
 358        perform_predict_edits: F,
 359    ) -> Task<Result<Option<InlineCompletion>>>
 360    where
 361        F: FnOnce(PerformPredictEditsParams) -> R + 'static,
 362        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 363    {
 364        let snapshot = self.report_changes_for_buffer(&buffer, cx);
 365        let diagnostic_groups = snapshot.diagnostic_groups(None);
 366        let cursor_point = cursor.to_point(&snapshot);
 367        let cursor_offset = cursor_point.to_offset(&snapshot);
 368        let events = self.events.clone();
 369        let path: Arc<Path> = snapshot
 370            .file()
 371            .map(|f| Arc::from(f.full_path(cx).as_path()))
 372            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 373
 374        let zeta = cx.entity();
 375        let client = self.client.clone();
 376        let llm_token = self.llm_token.clone();
 377        let app_version = AppVersion::global(cx);
 378
 379        let buffer = buffer.clone();
 380
 381        let local_lsp_store =
 382            project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
 383        let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
 384            Some(
 385                diagnostic_groups
 386                    .into_iter()
 387                    .filter_map(|(language_server_id, diagnostic_group)| {
 388                        let language_server =
 389                            local_lsp_store.running_language_server_for_id(language_server_id)?;
 390
 391                        Some((
 392                            language_server.name(),
 393                            diagnostic_group.resolve::<usize>(&snapshot),
 394                        ))
 395                    })
 396                    .collect::<Vec<_>>(),
 397            )
 398        } else {
 399            None
 400        };
 401
 402        cx.spawn(async move |_, cx| {
 403            let request_sent_at = Instant::now();
 404
 405            struct BackgroundValues {
 406                input_events: String,
 407                input_excerpt: String,
 408                speculated_output: String,
 409                editable_range: Range<usize>,
 410                input_outline: String,
 411            }
 412
 413            let values = cx
 414                .background_spawn({
 415                    let snapshot = snapshot.clone();
 416                    let path = path.clone();
 417                    async move {
 418                        let path = path.to_string_lossy();
 419                        let input_excerpt = excerpt_for_cursor_position(
 420                            cursor_point,
 421                            &path,
 422                            &snapshot,
 423                            MAX_REWRITE_TOKENS,
 424                            MAX_CONTEXT_TOKENS,
 425                        );
 426                        let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
 427                        let input_outline = prompt_for_outline(&snapshot);
 428
 429                        anyhow::Ok(BackgroundValues {
 430                            input_events,
 431                            input_excerpt: input_excerpt.prompt,
 432                            speculated_output: input_excerpt.speculated_output,
 433                            editable_range: input_excerpt.editable_range.to_offset(&snapshot),
 434                            input_outline,
 435                        })
 436                    }
 437                })
 438                .await?;
 439
 440            log::debug!(
 441                "Events:\n{}\nExcerpt:\n{:?}",
 442                values.input_events,
 443                values.input_excerpt
 444            );
 445
 446            let body = PredictEditsBody {
 447                input_events: values.input_events.clone(),
 448                input_excerpt: values.input_excerpt.clone(),
 449                speculated_output: Some(values.speculated_output),
 450                outline: Some(values.input_outline.clone()),
 451                can_collect_data,
 452                diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
 453                    diagnostic_groups
 454                        .into_iter()
 455                        .map(|(name, diagnostic_group)| {
 456                            Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
 457                        })
 458                        .collect::<Result<Vec<_>>>()
 459                        .log_err()
 460                }),
 461            };
 462
 463            let response = perform_predict_edits(PerformPredictEditsParams {
 464                client,
 465                llm_token,
 466                app_version,
 467                body,
 468            })
 469            .await;
 470            let response = match response {
 471                Ok(response) => response,
 472                Err(err) => {
 473                    if err.is::<ZedUpdateRequiredError>() {
 474                        cx.update(|cx| {
 475                            zeta.update(cx, |zeta, _cx| {
 476                                zeta.update_required = true;
 477                            });
 478
 479                            if let Some(workspace) = workspace {
 480                                workspace.update(cx, |workspace, cx| {
 481                                    workspace.show_notification(
 482                                        NotificationId::unique::<ZedUpdateRequiredError>(),
 483                                        cx,
 484                                        |cx| {
 485                                            cx.new(|cx| {
 486                                                ErrorMessagePrompt::new(err.to_string(), cx)
 487                                                    .with_link_button(
 488                                                        "Update Zed",
 489                                                        "https://zed.dev/releases",
 490                                                    )
 491                                            })
 492                                        },
 493                                    );
 494                                });
 495                            }
 496                        })
 497                        .ok();
 498                    }
 499
 500                    return Err(err);
 501                }
 502            };
 503
 504            log::debug!("completion response: {}", &response.output_excerpt);
 505
 506            Self::process_completion_response(
 507                response,
 508                buffer,
 509                &snapshot,
 510                values.editable_range,
 511                cursor_offset,
 512                path,
 513                values.input_outline,
 514                values.input_events,
 515                values.input_excerpt,
 516                request_sent_at,
 517                &cx,
 518            )
 519            .await
 520        })
 521    }
 522
 523    // Generates several example completions of various states to fill the Zeta completion modal
 524    #[cfg(any(test, feature = "test-support"))]
 525    pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
 526        use language::Point;
 527
 528        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 529            And maybe a short line
 530
 531            Then a few lines
 532
 533            and then another
 534            "#};
 535
 536        let project = None;
 537        let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
 538        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 539
 540        let completion_tasks = vec![
 541            self.fake_completion(
 542                project,
 543                &buffer,
 544                position,
 545                PredictEditsResponse {
 546                    request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
 547                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 548a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 549[here's an edit]
 550And maybe a short line
 551Then a few lines
 552and then another
 553{EDITABLE_REGION_END_MARKER}
 554                        ", ),
 555                },
 556                cx,
 557            ),
 558            self.fake_completion(
 559                project,
 560                &buffer,
 561                position,
 562                PredictEditsResponse {
 563                    request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
 564                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 565a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 566And maybe a short line
 567[and another edit]
 568Then a few lines
 569and then another
 570{EDITABLE_REGION_END_MARKER}
 571                        "#),
 572                },
 573                cx,
 574            ),
 575            self.fake_completion(
 576                project,
 577                &buffer,
 578                position,
 579                PredictEditsResponse {
 580                    request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
 581                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 582a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 583And maybe a short line
 584
 585Then a few lines
 586
 587and then another
 588{EDITABLE_REGION_END_MARKER}
 589                        "#),
 590                },
 591                cx,
 592            ),
 593            self.fake_completion(
 594                project,
 595                &buffer,
 596                position,
 597                PredictEditsResponse {
 598                    request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
 599                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 600a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 601And maybe a short line
 602
 603Then a few lines
 604
 605and then another
 606{EDITABLE_REGION_END_MARKER}
 607                        "#),
 608                },
 609                cx,
 610            ),
 611            self.fake_completion(
 612                project,
 613                &buffer,
 614                position,
 615                PredictEditsResponse {
 616                    request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
 617                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 618a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 619And maybe a short line
 620Then a few lines
 621[a third completion]
 622and then another
 623{EDITABLE_REGION_END_MARKER}
 624                        "#),
 625                },
 626                cx,
 627            ),
 628            self.fake_completion(
 629                project,
 630                &buffer,
 631                position,
 632                PredictEditsResponse {
 633                    request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
 634                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 635a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 636And maybe a short line
 637and then another
 638[fourth completion example]
 639{EDITABLE_REGION_END_MARKER}
 640                        "#),
 641                },
 642                cx,
 643            ),
 644            self.fake_completion(
 645                project,
 646                &buffer,
 647                position,
 648                PredictEditsResponse {
 649                    request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
 650                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 651a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 652And maybe a short line
 653Then a few lines
 654and then another
 655[fifth and final completion]
 656{EDITABLE_REGION_END_MARKER}
 657                        "#),
 658                },
 659                cx,
 660            ),
 661        ];
 662
 663        cx.spawn(async move |zeta, cx| {
 664            for task in completion_tasks {
 665                task.await.unwrap();
 666            }
 667
 668            zeta.update(cx, |zeta, _cx| {
 669                zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
 670                zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
 671            })
 672            .ok();
 673        })
 674    }
 675
 676    #[cfg(any(test, feature = "test-support"))]
 677    pub fn fake_completion(
 678        &mut self,
 679        project: Option<&Entity<Project>>,
 680        buffer: &Entity<Buffer>,
 681        position: language::Anchor,
 682        response: PredictEditsResponse,
 683        cx: &mut Context<Self>,
 684    ) -> Task<Result<Option<InlineCompletion>>> {
 685        use std::future::ready;
 686
 687        self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
 688            ready(Ok(response))
 689        })
 690    }
 691
 692    pub fn request_completion(
 693        &mut self,
 694        project: Option<&Entity<Project>>,
 695        buffer: &Entity<Buffer>,
 696        position: language::Anchor,
 697        can_collect_data: bool,
 698        cx: &mut Context<Self>,
 699    ) -> Task<Result<Option<InlineCompletion>>> {
 700        let workspace = self
 701            .workspace
 702            .as_ref()
 703            .and_then(|workspace| workspace.upgrade());
 704        self.request_completion_impl(
 705            workspace,
 706            project,
 707            buffer,
 708            position,
 709            can_collect_data,
 710            cx,
 711            Self::perform_predict_edits,
 712        )
 713    }
 714
 715    fn perform_predict_edits(
 716        params: PerformPredictEditsParams,
 717    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 718        async move {
 719            let PerformPredictEditsParams {
 720                client,
 721                llm_token,
 722                app_version,
 723                body,
 724                ..
 725            } = params;
 726
 727            let http_client = client.http_client();
 728            let mut token = llm_token.acquire(&client).await?;
 729            let mut did_retry = false;
 730
 731            loop {
 732                let request_builder = http_client::Request::builder().method(Method::POST);
 733                let request_builder =
 734                    if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 735                        request_builder.uri(predict_edits_url)
 736                    } else {
 737                        request_builder.uri(
 738                            http_client
 739                                .build_zed_llm_url("/predict_edits/v2", &[])?
 740                                .as_ref(),
 741                        )
 742                    };
 743                let request = request_builder
 744                    .header("Content-Type", "application/json")
 745                    .header("Authorization", format!("Bearer {}", token))
 746                    .body(serde_json::to_string(&body)?.into())?;
 747
 748                let mut response = http_client.send(request).await?;
 749
 750                if let Some(minimum_required_version) = response
 751                    .headers()
 752                    .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 753                    .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 754                {
 755                    if app_version < minimum_required_version {
 756                        return Err(anyhow!(ZedUpdateRequiredError {
 757                            minimum_version: minimum_required_version
 758                        }));
 759                    }
 760                }
 761
 762                if response.status().is_success() {
 763                    let mut body = String::new();
 764                    response.body_mut().read_to_string(&mut body).await?;
 765                    return Ok(serde_json::from_str(&body)?);
 766                } else if !did_retry
 767                    && response
 768                        .headers()
 769                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 770                        .is_some()
 771                {
 772                    did_retry = true;
 773                    token = llm_token.refresh(&client).await?;
 774                } else {
 775                    let mut body = String::new();
 776                    response.body_mut().read_to_string(&mut body).await?;
 777                    return Err(anyhow!(
 778                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 779                        response.status(),
 780                        body
 781                    ));
 782                }
 783            }
 784        }
 785    }
 786
    /// Turns a `PredictEditsResponse` into an [`InlineCompletion`].
    ///
    /// The response's output excerpt is parsed into edits on the background
    /// executor, against `snapshot` (the buffer snapshot passed in by the
    /// caller). Those edits are then interpolated onto the buffer's current
    /// snapshot. If `interpolate` returns `None` (presumably because intervening
    /// edits conflict with the prediction), the task resolves to `Ok(None)`.
    /// Otherwise an edit preview is awaited, and the completion is returned along
    /// with the request inputs and timing metadata.
    ///
    /// The task errors if `parse_edits` fails or if `buffer.read_with` fails.
    fn process_completion_response(
        prediction_response: PredictEditsResponse,
        buffer: Entity<Buffer>,
        snapshot: &BufferSnapshot,
        editable_range: Range<usize>,
        cursor_offset: usize,
        path: Arc<Path>,
        input_outline: String,
        input_events: String,
        input_excerpt: String,
        request_sent_at: Instant,
        cx: &AsyncApp,
    ) -> Task<Result<Option<InlineCompletion>>> {
        // Own the snapshot so it can move into the spawned task.
        let snapshot = snapshot.clone();
        let request_id = prediction_response.request_id;
        let output_excerpt = prediction_response.output_excerpt;
        cx.spawn(async move |cx| {
            let output_excerpt: Arc<str> = output_excerpt.into();

            // Parsing/diffing runs off the foreground; it only needs clones of the inputs.
            let edits: Arc<[(Range<Anchor>, String)]> = cx
                .background_spawn({
                    let output_excerpt = output_excerpt.clone();
                    let editable_range = editable_range.clone();
                    let snapshot = snapshot.clone();
                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                })
                .await?
                .into();

            // Re-base the edits from the request-time snapshot onto the buffer's
            // current snapshot; `None` from `interpolate` short-circuits to `Ok(None)`.
            let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                let edits = edits.clone();
                |buffer, cx| {
                    let new_snapshot = buffer.snapshot();
                    let edits: Arc<[(Range<Anchor>, String)]> =
                        interpolate(&snapshot, &new_snapshot, edits)?.into();
                    Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                }
            })?
            else {
                return anyhow::Ok(None);
            };

            let edit_preview = edit_preview.await;

            Ok(Some(InlineCompletion {
                id: InlineCompletionId(request_id),
                path,
                excerpt_range: editable_range,
                cursor_offset,
                edits,
                edit_preview,
                snapshot,
                input_outline: input_outline.into(),
                input_events: input_events.into(),
                input_excerpt: input_excerpt.into(),
                output_excerpt,
                request_sent_at,
                response_received_at: Instant::now(),
            }))
        })
    }
 848
 849    fn parse_edits(
 850        output_excerpt: Arc<str>,
 851        editable_range: Range<usize>,
 852        snapshot: &BufferSnapshot,
 853    ) -> Result<Vec<(Range<Anchor>, String)>> {
 854        let content = output_excerpt.replace(CURSOR_MARKER, "");
 855
 856        let start_markers = content
 857            .match_indices(EDITABLE_REGION_START_MARKER)
 858            .collect::<Vec<_>>();
 859        anyhow::ensure!(
 860            start_markers.len() == 1,
 861            "expected exactly one start marker, found {}",
 862            start_markers.len()
 863        );
 864
 865        let end_markers = content
 866            .match_indices(EDITABLE_REGION_END_MARKER)
 867            .collect::<Vec<_>>();
 868        anyhow::ensure!(
 869            end_markers.len() == 1,
 870            "expected exactly one end marker, found {}",
 871            end_markers.len()
 872        );
 873
 874        let sof_markers = content
 875            .match_indices(START_OF_FILE_MARKER)
 876            .collect::<Vec<_>>();
 877        anyhow::ensure!(
 878            sof_markers.len() <= 1,
 879            "expected at most one start-of-file marker, found {}",
 880            sof_markers.len()
 881        );
 882
 883        let codefence_start = start_markers[0].0;
 884        let content = &content[codefence_start..];
 885
 886        let newline_ix = content.find('\n').context("could not find newline")?;
 887        let content = &content[newline_ix + 1..];
 888
 889        let codefence_end = content
 890            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 891            .context("could not find end marker")?;
 892        let new_text = &content[..codefence_end];
 893
 894        let old_text = snapshot
 895            .text_for_range(editable_range.clone())
 896            .collect::<String>();
 897
 898        Ok(Self::compute_edits(
 899            old_text,
 900            new_text,
 901            editable_range.start,
 902            &snapshot,
 903        ))
 904    }
 905
 906    pub fn compute_edits(
 907        old_text: String,
 908        new_text: &str,
 909        offset: usize,
 910        snapshot: &BufferSnapshot,
 911    ) -> Vec<(Range<Anchor>, String)> {
 912        text_diff(&old_text, &new_text)
 913            .into_iter()
 914            .map(|(mut old_range, new_text)| {
 915                old_range.start += offset;
 916                old_range.end += offset;
 917
 918                let prefix_len = common_prefix(
 919                    snapshot.chars_for_range(old_range.clone()),
 920                    new_text.chars(),
 921                );
 922                old_range.start += prefix_len;
 923
 924                let suffix_len = common_prefix(
 925                    snapshot.reversed_chars_for_range(old_range.clone()),
 926                    new_text[prefix_len..].chars().rev(),
 927                );
 928                old_range.end = old_range.end.saturating_sub(suffix_len);
 929
 930                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 931                let range = if old_range.is_empty() {
 932                    let anchor = snapshot.anchor_after(old_range.start);
 933                    anchor..anchor
 934                } else {
 935                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 936                };
 937                (range, new_text)
 938            })
 939            .collect()
 940    }
 941
 942    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
 943        self.rated_completions.contains(&completion_id)
 944    }
 945
 946    pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) {
 947        self.shown_completions.push_front(completion.clone());
 948        if self.shown_completions.len() > 50 {
 949            let completion = self.shown_completions.pop_back().unwrap();
 950            self.rated_completions.remove(&completion.id);
 951        }
 952        cx.notify();
 953    }
 954
    /// Records the user's `rating` and free-form `feedback` for `completion`.
    ///
    /// The completion is marked as rated. An "Edit Prediction Rated" telemetry
    /// event is then emitted with the completion's inputs and output excerpt.
    /// Finally, telemetry is flushed immediately and observers are notified.
    pub fn rate_completion(
        &mut self,
        completion: &InlineCompletion,
        rating: InlineCompletionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_completions.insert(completion.id);
        // NOTE(review): `rating` and `feedback` are passed in shorthand form, which
        // presumably makes the local names double as event property names —
        // confirm against `telemetry::event!` before renaming them.
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            input_events = completion.input_events,
            input_excerpt = completion.input_excerpt,
            input_outline = completion.input_outline,
            output_excerpt = completion.output_excerpt,
            feedback
        );
        // Send ratings right away instead of waiting for the next batch.
        self.client.telemetry().flush_events();
        cx.notify();
    }
 975
 976    pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
 977        self.shown_completions.iter()
 978    }
 979
 980    pub fn shown_completions_len(&self) -> usize {
 981        self.shown_completions.len()
 982    }
 983
 984    fn report_changes_for_buffer(
 985        &mut self,
 986        buffer: &Entity<Buffer>,
 987        cx: &mut Context<Self>,
 988    ) -> BufferSnapshot {
 989        self.register_buffer(buffer, cx);
 990
 991        let registered_buffer = self
 992            .registered_buffers
 993            .get_mut(&buffer.entity_id())
 994            .unwrap();
 995        let new_snapshot = buffer.read(cx).snapshot();
 996
 997        if new_snapshot.version != registered_buffer.snapshot.version {
 998            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 999            self.push_event(Event::BufferChange {
1000                old_snapshot,
1001                new_snapshot: new_snapshot.clone(),
1002                timestamp: Instant::now(),
1003            });
1004        }
1005
1006        new_snapshot
1007    }
1008
1009    fn load_data_collection_choices() -> DataCollectionChoice {
1010        let choice = KEY_VALUE_STORE
1011            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1012            .log_err()
1013            .flatten();
1014
1015        match choice.as_deref() {
1016            Some("true") => DataCollectionChoice::Enabled,
1017            Some("false") => DataCollectionChoice::Disabled,
1018            Some(_) => {
1019                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1020                DataCollectionChoice::NotAnswered
1021            }
1022            None => DataCollectionChoice::NotAnswered,
1023        }
1024    }
1025}
1026
/// Inputs for a single predict-edits request.
struct PerformPredictEditsParams {
    /// Client used for the request.
    // NOTE(review): presumably also used to refresh an expired LLM token on
    // retry — confirm in the request path.
    pub client: Arc<Client>,
    /// Token authorizing the request against the LLM service.
    pub llm_token: LlmApiToken,
    /// Version of the running Zed build.
    // NOTE(review): presumably sent so the server can enforce a minimum
    // version — confirm.
    pub app_version: SemanticVersion,
    /// Request payload.
    pub body: PredictEditsBody,
}
1033
/// Error signalling that edit predictions require a newer Zed than the one
/// currently running.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version that will be accepted.
    // NOTE(review): presumably taken from the server's minimum-version
    // response header — confirm.
    minimum_version: SemanticVersion,
}
1041
1042struct LicenseDetectionWatcher {
1043    is_open_source_rx: watch::Receiver<bool>,
1044    _is_open_source_task: Task<()>,
1045}
1046
1047impl LicenseDetectionWatcher {
1048    pub fn new(worktree: &Worktree, cx: &mut Context<Worktree>) -> Self {
1049        let (mut is_open_source_tx, is_open_source_rx) = watch::channel_with::<bool>(false);
1050
1051        // Check if worktree is a single file, if so we do not need to check for a LICENSE file
1052        let task = if worktree.abs_path().is_file() {
1053            Task::ready(())
1054        } else {
1055            let loaded_files = LICENSE_FILES_TO_CHECK
1056                .iter()
1057                .map(Path::new)
1058                .map(|file| worktree.load_file(file, cx))
1059                .collect::<ArrayVec<_, { LICENSE_FILES_TO_CHECK.len() }>>();
1060
1061            cx.background_spawn(async move {
1062                for loaded_file in loaded_files.into_iter() {
1063                    let Ok(loaded_file) = loaded_file.await else {
1064                        continue;
1065                    };
1066
1067                    let path = &loaded_file.file.path;
1068                    if is_license_eligible_for_data_collection(&loaded_file.text) {
1069                        log::info!("detected '{path:?}' as open source license");
1070                        *is_open_source_tx.borrow_mut() = true;
1071                    } else {
1072                        log::info!("didn't detect '{path:?}' as open source license");
1073                    }
1074
1075                    // stop on the first license that successfully read
1076                    return;
1077                }
1078
1079                log::debug!("didn't find a license file to check, assuming closed source");
1080            })
1081        };
1082
1083        Self {
1084            is_open_source_rx,
1085            _is_open_source_task: task,
1086        }
1087    }
1088
1089    /// Answers false until we find out it's open source
1090    pub fn is_project_open_source(&self) -> bool {
1091        *self.is_open_source_rx.borrow()
1092    }
1093}
1094
/// Returns the byte length (UTF-8) of the longest common prefix of two
/// character streams.
///
/// Comparison stops at the first mismatching pair or when either stream runs
/// out. Lengths are measured using the characters taken from `a`.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
1101
1102fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
1103    let mut input_outline = String::new();
1104
1105    writeln!(
1106        input_outline,
1107        "```{}",
1108        snapshot
1109            .file()
1110            .map_or(Cow::Borrowed("untitled"), |file| file
1111                .path()
1112                .to_string_lossy())
1113    )
1114    .unwrap();
1115
1116    if let Some(outline) = snapshot.outline(None) {
1117        for item in &outline.items {
1118            let spacing = " ".repeat(item.depth);
1119            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
1120        }
1121    }
1122
1123    writeln!(input_outline, "```").unwrap();
1124
1125    input_outline
1126}
1127
1128fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
1129    let mut result = String::new();
1130    for event in events.iter().rev() {
1131        let event_string = event.to_prompt();
1132        let event_tokens = tokens_for_bytes(event_string.len());
1133        if event_tokens > remaining_tokens {
1134            break;
1135        }
1136
1137        if !result.is_empty() {
1138            result.insert_str(0, "\n\n");
1139        }
1140        result.insert_str(0, &event_string);
1141        remaining_tokens -= event_tokens;
1142    }
1143    result
1144}
1145
/// State kept for each buffer registered with [`Zeta`].
struct RegisteredBuffer {
    /// Snapshot last recorded for this buffer; a differing buffer version means
    /// a new change event should be reported.
    snapshot: BufferSnapshot,
    /// Subscriptions held while the buffer is registered.
    // NOTE(review): what these two observe is set up in `register_buffer`,
    // outside this excerpt — confirm there.
    _subscriptions: [gpui::Subscription; 2],
}
1150
1151#[derive(Clone)]
1152enum Event {
1153    BufferChange {
1154        old_snapshot: BufferSnapshot,
1155        new_snapshot: BufferSnapshot,
1156        timestamp: Instant,
1157    },
1158}
1159
1160impl Event {
1161    fn to_prompt(&self) -> String {
1162        match self {
1163            Event::BufferChange {
1164                old_snapshot,
1165                new_snapshot,
1166                ..
1167            } => {
1168                let mut prompt = String::new();
1169
1170                let old_path = old_snapshot
1171                    .file()
1172                    .map(|f| f.path().as_ref())
1173                    .unwrap_or(Path::new("untitled"));
1174                let new_path = new_snapshot
1175                    .file()
1176                    .map(|f| f.path().as_ref())
1177                    .unwrap_or(Path::new("untitled"));
1178                if old_path != new_path {
1179                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
1180                }
1181
1182                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
1183                if !diff.is_empty() {
1184                    write!(
1185                        prompt,
1186                        "User edited {:?}:\n```diff\n{}\n```",
1187                        new_path, diff
1188                    )
1189                    .unwrap();
1190                }
1191
1192                prompt
1193            }
1194        }
1195    }
1196}
1197
1198#[derive(Debug, Clone)]
1199struct CurrentInlineCompletion {
1200    buffer_id: EntityId,
1201    completion: InlineCompletion,
1202}
1203
1204impl CurrentInlineCompletion {
1205    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
1206        if self.buffer_id != old_completion.buffer_id {
1207            return true;
1208        }
1209
1210        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
1211            return true;
1212        };
1213        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
1214            return false;
1215        };
1216
1217        if old_edits.len() == 1 && new_edits.len() == 1 {
1218            let (old_range, old_text) = &old_edits[0];
1219            let (new_range, new_text) = &new_edits[0];
1220            new_range == old_range && new_text.starts_with(old_text)
1221        } else {
1222            true
1223        }
1224    }
1225}
1226
/// A prediction request that is still in flight.
struct PendingCompletion {
    /// Sequence number taken from the provider's `next_pending_completion_id`.
    id: usize,
    /// The spawned request task, owned by this entry.
    // NOTE(review): presumably dropping the entry cancels the request (gpui
    // `Task` semantics) — confirm.
    _task: Task<()>,
}
1231
/// The user's answer on whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No stored answer yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// True only for an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::NotAnswered | Self::Disabled => false,
            Self::Enabled => true,
        }
    }

    /// True once the user has picked either option.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1271
1272pub struct ProviderDataCollection {
1273    /// When set to None, data collection is not possible in the provider buffer
1274    choice: Option<Entity<DataCollectionChoice>>,
1275    license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
1276}
1277
1278impl ProviderDataCollection {
1279    pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
1280        let choice_and_watcher = buffer.and_then(|buffer| {
1281            let file = buffer.read(cx).file()?;
1282
1283            if !file.is_local() || file.is_private() {
1284                return None;
1285            }
1286
1287            let zeta = zeta.read(cx);
1288            let choice = zeta.data_collection_choice.clone();
1289
1290            let license_detection_watcher = zeta
1291                .license_detection_watchers
1292                .get(&file.worktree_id(cx))
1293                .cloned()?;
1294
1295            Some((choice, license_detection_watcher))
1296        });
1297
1298        if let Some((choice, watcher)) = choice_and_watcher {
1299            ProviderDataCollection {
1300                choice: Some(choice),
1301                license_detection_watcher: Some(watcher),
1302            }
1303        } else {
1304            ProviderDataCollection {
1305                choice: None,
1306                license_detection_watcher: None,
1307            }
1308        }
1309    }
1310
1311    pub fn can_collect_data(&self, cx: &App) -> bool {
1312        self.is_data_collection_enabled(cx) && self.is_project_open_source()
1313    }
1314
1315    pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
1316        self.choice
1317            .as_ref()
1318            .is_some_and(|choice| choice.read(cx).is_enabled())
1319    }
1320
1321    fn is_project_open_source(&self) -> bool {
1322        self.license_detection_watcher
1323            .as_ref()
1324            .is_some_and(|watcher| watcher.is_project_open_source())
1325    }
1326
1327    pub fn toggle(&mut self, cx: &mut App) {
1328        if let Some(choice) = self.choice.as_mut() {
1329            let new_choice = choice.update(cx, |choice, _cx| {
1330                let new_choice = choice.toggle();
1331                *choice = new_choice;
1332                new_choice
1333            });
1334
1335            db::write_and_log(cx, move || {
1336                KEY_VALUE_STORE.write_kvp(
1337                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1338                    new_choice.is_enabled().to_string(),
1339                )
1340            });
1341        }
1342    }
1343}
1344
1345pub struct ZetaInlineCompletionProvider {
1346    zeta: Entity<Zeta>,
1347    pending_completions: ArrayVec<PendingCompletion, 2>,
1348    next_pending_completion_id: usize,
1349    current_completion: Option<CurrentInlineCompletion>,
1350    /// None if this is entirely disabled for this provider
1351    provider_data_collection: ProviderDataCollection,
1352    last_request_timestamp: Instant,
1353}
1354
1355impl ZetaInlineCompletionProvider {
1356    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1357
1358    pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self {
1359        Self {
1360            zeta,
1361            pending_completions: ArrayVec::new(),
1362            next_pending_completion_id: 0,
1363            current_completion: None,
1364            provider_data_collection,
1365            last_request_timestamp: Instant::now(),
1366        }
1367    }
1368}
1369
1370impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider {
1371    fn name() -> &'static str {
1372        "zed-predict"
1373    }
1374
1375    fn display_name() -> &'static str {
1376        "Zed's Edit Predictions"
1377    }
1378
1379    fn show_completions_in_menu() -> bool {
1380        true
1381    }
1382
1383    fn show_tab_accept_marker() -> bool {
1384        true
1385    }
1386
1387    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
1388        let is_project_open_source = self.provider_data_collection.is_project_open_source();
1389
1390        if self.provider_data_collection.is_data_collection_enabled(cx) {
1391            DataCollectionState::Enabled {
1392                is_project_open_source,
1393            }
1394        } else {
1395            DataCollectionState::Disabled {
1396                is_project_open_source,
1397            }
1398        }
1399    }
1400
1401    fn toggle_data_collection(&mut self, cx: &mut App) {
1402        self.provider_data_collection.toggle(cx);
1403    }
1404
1405    fn is_enabled(
1406        &self,
1407        _buffer: &Entity<Buffer>,
1408        _cursor_position: language::Anchor,
1409        _cx: &App,
1410    ) -> bool {
1411        true
1412    }
1413
1414    fn needs_terms_acceptance(&self, cx: &App) -> bool {
1415        !self.zeta.read(cx).tos_accepted
1416    }
1417
1418    fn is_refreshing(&self) -> bool {
1419        !self.pending_completions.is_empty()
1420    }
1421
1422    fn refresh(
1423        &mut self,
1424        project: Option<Entity<Project>>,
1425        buffer: Entity<Buffer>,
1426        position: language::Anchor,
1427        _debounce: bool,
1428        cx: &mut Context<Self>,
1429    ) {
1430        if !self.zeta.read(cx).tos_accepted {
1431            return;
1432        }
1433
1434        if self.zeta.read(cx).update_required {
1435            return;
1436        }
1437
1438        if let Some(current_completion) = self.current_completion.as_ref() {
1439            let snapshot = buffer.read(cx).snapshot();
1440            if current_completion
1441                .completion
1442                .interpolate(&snapshot)
1443                .is_some()
1444            {
1445                return;
1446            }
1447        }
1448
1449        let pending_completion_id = self.next_pending_completion_id;
1450        self.next_pending_completion_id += 1;
1451        let can_collect_data = self.provider_data_collection.can_collect_data(cx);
1452        let last_request_timestamp = self.last_request_timestamp;
1453
1454        let task = cx.spawn(async move |this, cx| {
1455            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
1456                .checked_duration_since(Instant::now())
1457            {
1458                cx.background_executor().timer(timeout).await;
1459            }
1460
1461            let completion_request = this.update(cx, |this, cx| {
1462                this.last_request_timestamp = Instant::now();
1463                this.zeta.update(cx, |zeta, cx| {
1464                    zeta.request_completion(
1465                        project.as_ref(),
1466                        &buffer,
1467                        position,
1468                        can_collect_data,
1469                        cx,
1470                    )
1471                })
1472            });
1473
1474            let completion = match completion_request {
1475                Ok(completion_request) => {
1476                    let completion_request = completion_request.await;
1477                    completion_request.map(|c| {
1478                        c.map(|completion| CurrentInlineCompletion {
1479                            buffer_id: buffer.entity_id(),
1480                            completion,
1481                        })
1482                    })
1483                }
1484                Err(error) => Err(error),
1485            };
1486            let Some(new_completion) = completion
1487                .context("edit prediction failed")
1488                .log_err()
1489                .flatten()
1490            else {
1491                this.update(cx, |this, cx| {
1492                    if this.pending_completions[0].id == pending_completion_id {
1493                        this.pending_completions.remove(0);
1494                    } else {
1495                        this.pending_completions.clear();
1496                    }
1497
1498                    cx.notify();
1499                })
1500                .ok();
1501                return;
1502            };
1503
1504            this.update(cx, |this, cx| {
1505                if this.pending_completions[0].id == pending_completion_id {
1506                    this.pending_completions.remove(0);
1507                } else {
1508                    this.pending_completions.clear();
1509                }
1510
1511                if let Some(old_completion) = this.current_completion.as_ref() {
1512                    let snapshot = buffer.read(cx).snapshot();
1513                    if new_completion.should_replace_completion(&old_completion, &snapshot) {
1514                        this.zeta.update(cx, |zeta, cx| {
1515                            zeta.completion_shown(&new_completion.completion, cx);
1516                        });
1517                        this.current_completion = Some(new_completion);
1518                    }
1519                } else {
1520                    this.zeta.update(cx, |zeta, cx| {
1521                        zeta.completion_shown(&new_completion.completion, cx);
1522                    });
1523                    this.current_completion = Some(new_completion);
1524                }
1525
1526                cx.notify();
1527            })
1528            .ok();
1529        });
1530
1531        // We always maintain at most two pending completions. When we already
1532        // have two, we replace the newest one.
1533        if self.pending_completions.len() <= 1 {
1534            self.pending_completions.push(PendingCompletion {
1535                id: pending_completion_id,
1536                _task: task,
1537            });
1538        } else if self.pending_completions.len() == 2 {
1539            self.pending_completions.pop();
1540            self.pending_completions.push(PendingCompletion {
1541                id: pending_completion_id,
1542                _task: task,
1543            });
1544        }
1545    }
1546
1547    fn cycle(
1548        &mut self,
1549        _buffer: Entity<Buffer>,
1550        _cursor_position: language::Anchor,
1551        _direction: inline_completion::Direction,
1552        _cx: &mut Context<Self>,
1553    ) {
1554        // Right now we don't support cycling.
1555    }
1556
1557    fn accept(&mut self, _cx: &mut Context<Self>) {
1558        self.pending_completions.clear();
1559    }
1560
1561    fn discard(&mut self, _cx: &mut Context<Self>) {
1562        self.pending_completions.clear();
1563        self.current_completion.take();
1564    }
1565
1566    fn suggest(
1567        &mut self,
1568        buffer: &Entity<Buffer>,
1569        cursor_position: language::Anchor,
1570        cx: &mut Context<Self>,
1571    ) -> Option<inline_completion::InlineCompletion> {
1572        let CurrentInlineCompletion {
1573            buffer_id,
1574            completion,
1575            ..
1576        } = self.current_completion.as_mut()?;
1577
1578        // Invalidate previous completion if it was generated for a different buffer.
1579        if *buffer_id != buffer.entity_id() {
1580            self.current_completion.take();
1581            return None;
1582        }
1583
1584        let buffer = buffer.read(cx);
1585        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1586            self.current_completion.take();
1587            return None;
1588        };
1589
1590        let cursor_row = cursor_position.to_point(buffer).row;
1591        let (closest_edit_ix, (closest_edit_range, _)) =
1592            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1593                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1594                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1595                cmp::min(distance_from_start, distance_from_end)
1596            })?;
1597
1598        let mut edit_start_ix = closest_edit_ix;
1599        for (range, _) in edits[..edit_start_ix].iter().rev() {
1600            let distance_from_closest_edit =
1601                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1602            if distance_from_closest_edit <= 1 {
1603                edit_start_ix -= 1;
1604            } else {
1605                break;
1606            }
1607        }
1608
1609        let mut edit_end_ix = closest_edit_ix + 1;
1610        for (range, _) in &edits[edit_end_ix..] {
1611            let distance_from_closest_edit =
1612                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1613            if distance_from_closest_edit <= 1 {
1614                edit_end_ix += 1;
1615            } else {
1616                break;
1617            }
1618        }
1619
1620        Some(inline_completion::InlineCompletion {
1621            id: Some(completion.id.to_string().into()),
1622            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1623            edit_preview: Some(completion.edit_preview.clone()),
1624        })
1625    }
1626}
1627
/// Estimates how many model tokens `bytes` bytes of prompt text will use,
/// rounding down.
///
/// The bytes-per-token ratio is deliberately low, so the estimate leans toward
/// over-counting tokens and keeps prompts conservatively within limits.
fn tokens_for_bytes(bytes: usize) -> usize {
    const BYTES_PER_TOKEN_GUESS: usize = 3;

    bytes / BYTES_PER_TOKEN_GUESS
}
1634
1635#[cfg(test)]
1636mod tests {
1637    use client::test::FakeServer;
1638    use clock::FakeSystemClock;
1639    use gpui::TestAppContext;
1640    use http_client::FakeHttpClient;
1641    use indoc::indoc;
1642    use language::Point;
1643    use rpc::proto;
1644    use settings::SettingsStore;
1645
1646    use super::*;
1647
1648    #[gpui::test]
1649    async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
1650        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1651        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1652            to_completion_edits(
1653                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1654                &buffer,
1655                cx,
1656            )
1657            .into()
1658        });
1659
1660        let edit_preview = cx
1661            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1662            .await;
1663
1664        let completion = InlineCompletion {
1665            edits,
1666            edit_preview,
1667            path: Path::new("").into(),
1668            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1669            id: InlineCompletionId(Uuid::new_v4()),
1670            excerpt_range: 0..0,
1671            cursor_offset: 0,
1672            input_outline: "".into(),
1673            input_events: "".into(),
1674            input_excerpt: "".into(),
1675            output_excerpt: "".into(),
1676            request_sent_at: Instant::now(),
1677            response_received_at: Instant::now(),
1678        };
1679
1680        cx.update(|cx| {
1681            assert_eq!(
1682                from_completion_edits(
1683                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1684                    &buffer,
1685                    cx
1686                ),
1687                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1688            );
1689
1690            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1691            assert_eq!(
1692                from_completion_edits(
1693                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1694                    &buffer,
1695                    cx
1696                ),
1697                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1698            );
1699
1700            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1701            assert_eq!(
1702                from_completion_edits(
1703                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1704                    &buffer,
1705                    cx
1706                ),
1707                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1708            );
1709
1710            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1711            assert_eq!(
1712                from_completion_edits(
1713                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1714                    &buffer,
1715                    cx
1716                ),
1717                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1718            );
1719
1720            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1721            assert_eq!(
1722                from_completion_edits(
1723                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1724                    &buffer,
1725                    cx
1726                ),
1727                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1728            );
1729
1730            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1731            assert_eq!(
1732                from_completion_edits(
1733                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1734                    &buffer,
1735                    cx
1736                ),
1737                vec![(9..11, "".to_string())]
1738            );
1739
1740            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1741            assert_eq!(
1742                from_completion_edits(
1743                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1744                    &buffer,
1745                    cx
1746                ),
1747                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1748            );
1749
1750            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1751            assert_eq!(
1752                from_completion_edits(
1753                    &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1754                    &buffer,
1755                    cx
1756                ),
1757                vec![(4..4, "M".to_string())]
1758            );
1759
1760            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1761            assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1762        })
1763    }
1764
1765    #[gpui::test]
1766    async fn test_clean_up_diff(cx: &mut TestAppContext) {
1767        cx.update(|cx| {
1768            let settings_store = SettingsStore::test(cx);
1769            cx.set_global(settings_store);
1770            client::init_settings(cx);
1771        });
1772
1773        let edits = edits_for_prediction(
1774            indoc! {"
1775                fn main() {
1776                    let word_1 = \"lorem\";
1777                    let range = word.len()..word.len();
1778                }
1779            "},
1780            indoc! {"
1781                <|editable_region_start|>
1782                fn main() {
1783                    let word_1 = \"lorem\";
1784                    let range = word_1.len()..word_1.len();
1785                }
1786
1787                <|editable_region_end|>
1788            "},
1789            cx,
1790        )
1791        .await;
1792        assert_eq!(
1793            edits,
1794            [
1795                (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
1796                (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
1797            ]
1798        );
1799
1800        let edits = edits_for_prediction(
1801            indoc! {"
1802                fn main() {
1803                    let story = \"the quick\"
1804                }
1805            "},
1806            indoc! {"
1807                <|editable_region_start|>
1808                fn main() {
1809                    let story = \"the quick brown fox jumps over the lazy dog\";
1810                }
1811
1812                <|editable_region_end|>
1813            "},
1814            cx,
1815        )
1816        .await;
1817        assert_eq!(
1818            edits,
1819            [
1820                (
1821                    Point::new(1, 26)..Point::new(1, 26),
1822                    " brown fox jumps over the lazy dog".to_string()
1823                ),
1824                (Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
1825            ]
1826        );
1827    }
1828
1829    #[gpui::test]
1830    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
1831        cx.update(|cx| {
1832            let settings_store = SettingsStore::test(cx);
1833            cx.set_global(settings_store);
1834            client::init_settings(cx);
1835        });
1836
1837        let buffer_content = "lorem\n";
1838        let completion_response = indoc! {"
1839            ```animals.js
1840            <|start_of_file|>
1841            <|editable_region_start|>
1842            lorem
1843            ipsum
1844            <|editable_region_end|>
1845            ```"};
1846
1847        let http_client = FakeHttpClient::create(move |_| async move {
1848            Ok(http_client::Response::builder()
1849                .status(200)
1850                .body(
1851                    serde_json::to_string(&PredictEditsResponse {
1852                        request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
1853                            .unwrap(),
1854                        output_excerpt: completion_response.to_string(),
1855                    })
1856                    .unwrap()
1857                    .into(),
1858                )
1859                .unwrap())
1860        });
1861
1862        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1863        cx.update(|cx| {
1864            RefreshLlmTokenListener::register(client.clone(), cx);
1865        });
1866        let server = FakeServer::for_client(42, &client, cx).await;
1867        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1868        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
1869
1870        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1871        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1872        let completion_task = zeta.update(cx, |zeta, cx| {
1873            zeta.request_completion(None, &buffer, cursor, false, cx)
1874        });
1875
1876        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
1877        server.respond(
1878            token_request.receipt(),
1879            proto::GetLlmTokenResponse { token: "".into() },
1880        );
1881
1882        let completion = completion_task.await.unwrap().unwrap();
1883        buffer.update(cx, |buffer, cx| {
1884            buffer.edit(completion.edits.iter().cloned(), None, cx)
1885        });
1886        assert_eq!(
1887            buffer.read_with(cx, |buffer, _| buffer.text()),
1888            "lorem\nipsum"
1889        );
1890    }
1891
1892    async fn edits_for_prediction(
1893        buffer_content: &str,
1894        completion_response: &str,
1895        cx: &mut TestAppContext,
1896    ) -> Vec<(Range<Point>, String)> {
1897        let completion_response = completion_response.to_string();
1898        let http_client = FakeHttpClient::create(move |_| {
1899            let completion = completion_response.clone();
1900            async move {
1901                Ok(http_client::Response::builder()
1902                    .status(200)
1903                    .body(
1904                        serde_json::to_string(&PredictEditsResponse {
1905                            request_id: Uuid::new_v4(),
1906                            output_excerpt: completion,
1907                        })
1908                        .unwrap()
1909                        .into(),
1910                    )
1911                    .unwrap())
1912            }
1913        });
1914
1915        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1916        cx.update(|cx| {
1917            RefreshLlmTokenListener::register(client.clone(), cx);
1918        });
1919        let server = FakeServer::for_client(42, &client, cx).await;
1920        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1921        let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
1922
1923        let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1924        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
1925        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1926        let completion_task = zeta.update(cx, |zeta, cx| {
1927            zeta.request_completion(None, &buffer, cursor, false, cx)
1928        });
1929
1930        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
1931        server.respond(
1932            token_request.receipt(),
1933            proto::GetLlmTokenResponse { token: "".into() },
1934        );
1935
1936        let completion = completion_task.await.unwrap().unwrap();
1937        completion
1938            .edits
1939            .into_iter()
1940            .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
1941            .collect::<Vec<_>>()
1942    }
1943
1944    fn to_completion_edits(
1945        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1946        buffer: &Entity<Buffer>,
1947        cx: &App,
1948    ) -> Vec<(Range<Anchor>, String)> {
1949        let buffer = buffer.read(cx);
1950        iterator
1951            .into_iter()
1952            .map(|(range, text)| {
1953                (
1954                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1955                    text,
1956                )
1957            })
1958            .collect()
1959    }
1960
1961    fn from_completion_edits(
1962        editor_edits: &[(Range<Anchor>, String)],
1963        buffer: &Entity<Buffer>,
1964        cx: &App,
1965    ) -> Vec<(Range<usize>, String)> {
1966        let buffer = buffer.read(cx);
1967        editor_edits
1968            .iter()
1969            .map(|(range, text)| {
1970                (
1971                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
1972                    text.clone(),
1973                )
1974            })
1975            .collect()
1976    }
1977
1978    #[ctor::ctor]
1979    fn init_logger() {
1980        if std::env::var("RUST_LOG").is_ok() {
1981            env_logger::init();
1982        }
1983    }
1984}