zeta.rs

   1mod rate_completion_modal;
   2
   3pub use rate_completion_modal::*;
   4
   5use anyhow::{anyhow, Context as _, Result};
   6use arrayvec::ArrayVec;
   7use client::Client;
   8use collections::{HashMap, HashSet, VecDeque};
   9use futures::AsyncReadExt;
  10use gpui::{
  11    actions, AppContext, AsyncAppContext, Context, EntityId, Global, Model, ModelContext,
  12    Subscription, Task,
  13};
  14use http_client::{HttpClient, Method};
  15use language::{
  16    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
  17    Point, ToOffset, ToPoint,
  18};
  19use language_models::LlmApiToken;
  20use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
  21use std::{
  22    borrow::Cow,
  23    cmp,
  24    fmt::Write,
  25    future::Future,
  26    mem,
  27    ops::Range,
  28    path::Path,
  29    sync::Arc,
  30    time::{Duration, Instant},
  31};
  32use telemetry_events::InlineCompletionRating;
  33use uuid::Uuid;
  34
  35const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  36const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  37const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  38const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  39const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  40
// Declares the `ClearHistory` action under the `zeta` action namespace.
actions!(zeta, [ClearHistory]);
  42
  43#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  44pub struct InlineCompletionId(Uuid);
  45
  46impl From<InlineCompletionId> for gpui::ElementId {
  47    fn from(value: InlineCompletionId) -> Self {
  48        gpui::ElementId::Uuid(value.0)
  49    }
  50}
  51
  52impl std::fmt::Display for InlineCompletionId {
  53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  54        write!(f, "{}", self.0)
  55    }
  56}
  57
  58impl InlineCompletionId {
  59    fn new() -> Self {
  60        Self(Uuid::new_v4())
  61    }
  62}
  63
/// Wrapper that stores the shared [`Zeta`] model as a gpui global; read back by
/// `Zeta::global` and installed by `Zeta::register`.
#[derive(Clone)]
struct ZetaGlobal(Model<Zeta>);

impl Global for ZetaGlobal {}
  68
/// A prediction returned by the model, together with the inputs that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Unique id of this completion.
    id: InlineCompletionId,
    /// Path of the buffer's file, or `untitled` when the buffer has no file.
    path: Arc<Path>,
    /// Buffer offsets of the excerpt that was sent as the editable region.
    excerpt_range: Range<usize>,
    /// Proposed edits, anchored in the buffer the completion was computed for.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the request was built from; `edits` are relative to it.
    snapshot: BufferSnapshot,
    /// Outline prompt text computed for the request.
    input_outline: Arc<str>,
    /// Serialized edit history sent with the request.
    input_events: Arc<str>,
    /// Excerpt prompt (with cursor and region markers) sent with the request.
    input_excerpt: Arc<str>,
    /// Raw excerpt text returned by the model.
    output_excerpt: Arc<str>,
    /// Recorded when the request task started assembling the prompt.
    request_sent_at: Instant,
    /// Recorded once the response had been turned into this completion.
    response_received_at: Instant,
}
  83
  84impl InlineCompletion {
  85    fn latency(&self) -> Duration {
  86        self.response_received_at
  87            .duration_since(self.request_sent_at)
  88    }
  89
  90    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
  91        let mut edits = Vec::new();
  92
  93        let mut user_edits = new_snapshot
  94            .edits_since::<usize>(&self.snapshot.version)
  95            .peekable();
  96        for (model_old_range, model_new_text) in self.edits.iter() {
  97            let model_offset_range = model_old_range.to_offset(&self.snapshot);
  98            while let Some(next_user_edit) = user_edits.peek() {
  99                if next_user_edit.old.end < model_offset_range.start {
 100                    user_edits.next();
 101                } else {
 102                    break;
 103                }
 104            }
 105
 106            if let Some(user_edit) = user_edits.peek() {
 107                if user_edit.old.start > model_offset_range.end {
 108                    edits.push((model_old_range.clone(), model_new_text.clone()));
 109                } else if user_edit.old == model_offset_range {
 110                    let user_new_text = new_snapshot
 111                        .text_for_range(user_edit.new.clone())
 112                        .collect::<String>();
 113
 114                    if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 115                        if !model_suffix.is_empty() {
 116                            edits.push((
 117                                new_snapshot.anchor_after(user_edit.new.end)
 118                                    ..new_snapshot.anchor_before(user_edit.new.end),
 119                                model_suffix.into(),
 120                            ));
 121                        }
 122
 123                        user_edits.next();
 124                    } else {
 125                        return None;
 126                    }
 127                } else {
 128                    return None;
 129                }
 130            } else {
 131                edits.push((model_old_range.clone(), model_new_text.clone()));
 132            }
 133        }
 134
 135        if edits.is_empty() {
 136            None
 137        } else {
 138            Some(edits)
 139        }
 140    }
 141}
 142
 143impl std::fmt::Debug for InlineCompletion {
 144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 145        f.debug_struct("InlineCompletion")
 146            .field("id", &self.id)
 147            .field("path", &self.path)
 148            .field("edits", &self.edits)
 149            .finish_non_exhaustive()
 150    }
 151}
 152
/// Edit-prediction model: records recent buffer changes, requests inline completions,
/// and keeps recently produced completions for review and rating.
pub struct Zeta {
    /// Client used for HTTP requests, LLM token management and telemetry.
    client: Arc<Client>,
    /// Recent buffer-change history (oldest first) that is serialized into each prompt.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by entity id, with their last reported snapshot.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Recently produced completions, newest first.
    recent_completions: VecDeque<InlineCompletion>,
    /// Ids reported by `is_completion_rated`; pruned when a completion is evicted from
    /// `recent_completions`.
    rated_completions: HashSet<InlineCompletionId>,
    /// Ids marked via `completion_shown`; pruned alongside `rated_completions`.
    shown_completions: HashSet<InlineCompletionId>,
    /// Token used to authenticate prediction requests.
    llm_token: LlmApiToken,
    /// Subscription that refreshes `llm_token` when the global refresh listener emits.
    _llm_token_subscription: Subscription,
}
 163
 164impl Zeta {
 165    pub fn global(cx: &mut AppContext) -> Option<Model<Self>> {
 166        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 167    }
 168
 169    pub fn register(client: Arc<Client>, cx: &mut AppContext) -> Model<Self> {
 170        Self::global(cx).unwrap_or_else(|| {
 171            let model = cx.new_model(|cx| Self::new(client, cx));
 172            cx.set_global(ZetaGlobal(model.clone()));
 173            model
 174        })
 175    }
 176
 177    pub fn clear_history(&mut self) {
 178        self.events.clear();
 179    }
 180
 181    fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
 182        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
 183
 184        Self {
 185            client,
 186            events: VecDeque::new(),
 187            recent_completions: VecDeque::new(),
 188            rated_completions: HashSet::default(),
 189            shown_completions: HashSet::default(),
 190            registered_buffers: HashMap::default(),
 191            llm_token: LlmApiToken::default(),
 192            _llm_token_subscription: cx.subscribe(
 193                &refresh_llm_token_listener,
 194                |this, _listener, _event, cx| {
 195                    let client = this.client.clone();
 196                    let llm_token = this.llm_token.clone();
 197                    cx.spawn(|_this, _cx| async move {
 198                        llm_token.refresh(&client).await?;
 199                        anyhow::Ok(())
 200                    })
 201                    .detach_and_log_err(cx);
 202                },
 203            ),
 204        }
 205    }
 206
 207    fn push_event(&mut self, event: Event) {
 208        if let Some(Event::BufferChange {
 209            new_snapshot: last_new_snapshot,
 210            timestamp: last_timestamp,
 211            ..
 212        }) = self.events.back_mut()
 213        {
 214            // Coalesce edits for the same buffer when they happen one after the other.
 215            let Event::BufferChange {
 216                old_snapshot,
 217                new_snapshot,
 218                timestamp,
 219            } = &event;
 220
 221            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 222                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 223                && old_snapshot.version == last_new_snapshot.version
 224            {
 225                *last_new_snapshot = new_snapshot.clone();
 226                *last_timestamp = *timestamp;
 227                return;
 228            }
 229        }
 230
 231        self.events.push_back(event);
 232        if self.events.len() > 10 {
 233            self.events.pop_front();
 234        }
 235    }
 236
 237    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
 238        let buffer_id = buffer.entity_id();
 239        let weak_buffer = buffer.downgrade();
 240
 241        if let std::collections::hash_map::Entry::Vacant(entry) =
 242            self.registered_buffers.entry(buffer_id)
 243        {
 244            let snapshot = buffer.read(cx).snapshot();
 245
 246            entry.insert(RegisteredBuffer {
 247                snapshot,
 248                _subscriptions: [
 249                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 250                        this.handle_buffer_event(buffer, event, cx);
 251                    }),
 252                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 253                        this.registered_buffers.remove(&weak_buffer.entity_id());
 254                    }),
 255                ],
 256            });
 257        };
 258    }
 259
 260    fn handle_buffer_event(
 261        &mut self,
 262        buffer: Model<Buffer>,
 263        event: &language::BufferEvent,
 264        cx: &mut ModelContext<Self>,
 265    ) {
 266        match event {
 267            language::BufferEvent::Edited => {
 268                self.report_changes_for_buffer(&buffer, cx);
 269            }
 270            _ => {}
 271        }
 272    }
 273
 274    pub fn request_completion_impl<F, R>(
 275        &mut self,
 276        buffer: &Model<Buffer>,
 277        position: language::Anchor,
 278        cx: &mut ModelContext<Self>,
 279        perform_predict_edits: F,
 280    ) -> Task<Result<InlineCompletion>>
 281    where
 282        F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
 283        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 284    {
 285        let snapshot = self.report_changes_for_buffer(buffer, cx);
 286        let point = position.to_point(&snapshot);
 287        let offset = point.to_offset(&snapshot);
 288        let excerpt_range = excerpt_range_for_position(point, &snapshot);
 289        let events = self.events.clone();
 290        let path = snapshot
 291            .file()
 292            .map(|f| f.path().clone())
 293            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 294
 295        let client = self.client.clone();
 296        let llm_token = self.llm_token.clone();
 297
 298        cx.spawn(|this, mut cx| async move {
 299            let request_sent_at = Instant::now();
 300
 301            let input_events = cx
 302                .background_executor()
 303                .spawn(async move {
 304                    let mut input_events = String::new();
 305                    for event in events {
 306                        if !input_events.is_empty() {
 307                            input_events.push('\n');
 308                            input_events.push('\n');
 309                        }
 310                        input_events.push_str(&event.to_prompt());
 311                    }
 312                    input_events
 313                })
 314                .await;
 315
 316            let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
 317            let input_outline = prompt_for_outline(&snapshot);
 318
 319            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 320
 321            let body = PredictEditsParams {
 322                input_events: input_events.clone(),
 323                input_excerpt: input_excerpt.clone(),
 324            };
 325
 326            let response = perform_predict_edits(client, llm_token, body).await?;
 327
 328            let output_excerpt = response.output_excerpt;
 329            log::debug!("completion response: {}", output_excerpt);
 330
 331            let inline_completion = Self::process_completion_response(
 332                output_excerpt,
 333                &snapshot,
 334                excerpt_range,
 335                path,
 336                input_outline,
 337                input_events,
 338                input_excerpt,
 339                request_sent_at,
 340                &cx,
 341            )
 342            .await?;
 343
 344            this.update(&mut cx, |this, cx| {
 345                this.recent_completions
 346                    .push_front(inline_completion.clone());
 347                if this.recent_completions.len() > 50 {
 348                    let completion = this.recent_completions.pop_back().unwrap();
 349                    this.shown_completions.remove(&completion.id);
 350                    this.rated_completions.remove(&completion.id);
 351                }
 352                cx.notify();
 353            })?;
 354
 355            Ok(inline_completion)
 356        })
 357    }
 358
 359    // Generates several example completions of various states to fill the Zeta completion modal
 360    #[cfg(any(test, feature = "test-support"))]
 361    pub fn fill_with_fake_completions(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 362        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 363            And maybe a short line
 364
 365            Then a few lines
 366
 367            and then another
 368            "#};
 369
 370        let buffer = cx.new_model(|cx| Buffer::local(test_buffer_text, cx));
 371        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 372
 373        let completion_tasks = vec![
 374            self.fake_completion(
 375                &buffer,
 376                position,
 377                PredictEditsResponse {
 378                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 379a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 380[here's an edit]
 381And maybe a short line
 382Then a few lines
 383and then another
 384{EDITABLE_REGION_END_MARKER}
 385                        ", ),
 386                },
 387                cx,
 388            ),
 389            self.fake_completion(
 390                &buffer,
 391                position,
 392                PredictEditsResponse {
 393                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 394a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 395And maybe a short line
 396[and another edit]
 397Then a few lines
 398and then another
 399{EDITABLE_REGION_END_MARKER}
 400                        "#),
 401                },
 402                cx,
 403            ),
 404            self.fake_completion(
 405                &buffer,
 406                position,
 407                PredictEditsResponse {
 408                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 409a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 410And maybe a short line
 411
 412Then a few lines
 413
 414and then another
 415{EDITABLE_REGION_END_MARKER}
 416                        "#),
 417                },
 418                cx,
 419            ),
 420            self.fake_completion(
 421                &buffer,
 422                position,
 423                PredictEditsResponse {
 424                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 425a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 426And maybe a short line
 427
 428Then a few lines
 429
 430and then another
 431{EDITABLE_REGION_END_MARKER}
 432                        "#),
 433                },
 434                cx,
 435            ),
 436            self.fake_completion(
 437                &buffer,
 438                position,
 439                PredictEditsResponse {
 440                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 441a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 442And maybe a short line
 443Then a few lines
 444[a third completion]
 445and then another
 446{EDITABLE_REGION_END_MARKER}
 447                        "#),
 448                },
 449                cx,
 450            ),
 451            self.fake_completion(
 452                &buffer,
 453                position,
 454                PredictEditsResponse {
 455                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 456a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 457And maybe a short line
 458and then another
 459[fourth completion example]
 460{EDITABLE_REGION_END_MARKER}
 461                        "#),
 462                },
 463                cx,
 464            ),
 465            self.fake_completion(
 466                &buffer,
 467                position,
 468                PredictEditsResponse {
 469                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 470a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 471And maybe a short line
 472Then a few lines
 473and then another
 474[fifth and final completion]
 475{EDITABLE_REGION_END_MARKER}
 476                        "#),
 477                },
 478                cx,
 479            ),
 480        ];
 481
 482        cx.spawn(|zeta, mut cx| async move {
 483            for task in completion_tasks {
 484                task.await.unwrap();
 485            }
 486
 487            zeta.update(&mut cx, |zeta, _cx| {
 488                zeta.recent_completions.get_mut(2).unwrap().edits = Arc::new([]);
 489                zeta.recent_completions.get_mut(3).unwrap().edits = Arc::new([]);
 490            })
 491            .ok();
 492        })
 493    }
 494
 495    #[cfg(any(test, feature = "test-support"))]
 496    pub fn fake_completion(
 497        &mut self,
 498        buffer: &Model<Buffer>,
 499        position: language::Anchor,
 500        response: PredictEditsResponse,
 501        cx: &mut ModelContext<Self>,
 502    ) -> Task<Result<InlineCompletion>> {
 503        use std::future::ready;
 504
 505        self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response)))
 506    }
 507
 508    pub fn request_completion(
 509        &mut self,
 510        buffer: &Model<Buffer>,
 511        position: language::Anchor,
 512        cx: &mut ModelContext<Self>,
 513    ) -> Task<Result<InlineCompletion>> {
 514        self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
 515    }
 516
 517    fn perform_predict_edits(
 518        client: Arc<Client>,
 519        llm_token: LlmApiToken,
 520        body: PredictEditsParams,
 521    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 522        async move {
 523            let http_client = client.http_client();
 524            let mut token = llm_token.acquire(&client).await?;
 525            let mut did_retry = false;
 526
 527            loop {
 528                let request_builder = http_client::Request::builder();
 529                let request = request_builder
 530                    .method(Method::POST)
 531                    .uri(
 532                        http_client
 533                            .build_zed_llm_url("/predict_edits", &[])?
 534                            .as_ref(),
 535                    )
 536                    .header("Content-Type", "application/json")
 537                    .header("Authorization", format!("Bearer {}", token))
 538                    .body(serde_json::to_string(&body)?.into())?;
 539
 540                let mut response = http_client.send(request).await?;
 541
 542                if response.status().is_success() {
 543                    let mut body = String::new();
 544                    response.body_mut().read_to_string(&mut body).await?;
 545                    return Ok(serde_json::from_str(&body)?);
 546                } else if !did_retry
 547                    && response
 548                        .headers()
 549                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 550                        .is_some()
 551                {
 552                    did_retry = true;
 553                    token = llm_token.refresh(&client).await?;
 554                } else {
 555                    let mut body = String::new();
 556                    response.body_mut().read_to_string(&mut body).await?;
 557                    return Err(anyhow!(
 558                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 559                        response.status(),
 560                        body
 561                    ));
 562                }
 563            }
 564        }
 565    }
 566
 567    #[allow(clippy::too_many_arguments)]
 568    fn process_completion_response(
 569        output_excerpt: String,
 570        snapshot: &BufferSnapshot,
 571        excerpt_range: Range<usize>,
 572        path: Arc<Path>,
 573        input_outline: String,
 574        input_events: String,
 575        input_excerpt: String,
 576        request_sent_at: Instant,
 577        cx: &AsyncAppContext,
 578    ) -> Task<Result<InlineCompletion>> {
 579        let snapshot = snapshot.clone();
 580        cx.background_executor().spawn(async move {
 581            let content = output_excerpt.replace(CURSOR_MARKER, "");
 582
 583            let codefence_start = content
 584                .find(EDITABLE_REGION_START_MARKER)
 585                .context("could not find start marker")?;
 586            let content = &content[codefence_start..];
 587
 588            let newline_ix = content.find('\n').context("could not find newline")?;
 589            let content = &content[newline_ix + 1..];
 590
 591            let codefence_end = content
 592                .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 593                .context("could not find end marker")?;
 594            let new_text = &content[..codefence_end];
 595
 596            let old_text = snapshot
 597                .text_for_range(excerpt_range.clone())
 598                .collect::<String>();
 599
 600            let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
 601
 602            Ok(InlineCompletion {
 603                id: InlineCompletionId::new(),
 604                path,
 605                excerpt_range,
 606                edits: edits.into(),
 607                snapshot: snapshot.clone(),
 608                input_outline: input_outline.into(),
 609                input_events: input_events.into(),
 610                input_excerpt: input_excerpt.into(),
 611                output_excerpt: output_excerpt.into(),
 612                request_sent_at,
 613                response_received_at: Instant::now(),
 614            })
 615        })
 616    }
 617
 618    pub fn compute_edits(
 619        old_text: String,
 620        new_text: &str,
 621        offset: usize,
 622        snapshot: &BufferSnapshot,
 623    ) -> Vec<(Range<Anchor>, String)> {
 624        let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
 625
 626        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
 627        let mut old_start = offset;
 628        for change in diff.iter_all_changes() {
 629            let value = change.value();
 630            match change.tag() {
 631                similar::ChangeTag::Equal => {
 632                    old_start += value.len();
 633                }
 634                similar::ChangeTag::Delete => {
 635                    let old_end = old_start + value.len();
 636                    if let Some((last_old_range, _)) = edits.last_mut() {
 637                        if last_old_range.end == old_start {
 638                            last_old_range.end = old_end;
 639                        } else {
 640                            edits.push((old_start..old_end, String::new()));
 641                        }
 642                    } else {
 643                        edits.push((old_start..old_end, String::new()));
 644                    }
 645                    old_start = old_end;
 646                }
 647                similar::ChangeTag::Insert => {
 648                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
 649                        if last_old_range.end == old_start {
 650                            last_new_text.push_str(value);
 651                        } else {
 652                            edits.push((old_start..old_start, value.into()));
 653                        }
 654                    } else {
 655                        edits.push((old_start..old_start, value.into()));
 656                    }
 657                }
 658            }
 659        }
 660
 661        edits
 662            .into_iter()
 663            .map(|(mut old_range, new_text)| {
 664                let prefix_len = common_prefix(
 665                    snapshot.chars_for_range(old_range.clone()),
 666                    new_text.chars(),
 667                );
 668                old_range.start += prefix_len;
 669                let suffix_len = common_prefix(
 670                    snapshot.reversed_chars_for_range(old_range.clone()),
 671                    new_text[prefix_len..].chars().rev(),
 672                );
 673                old_range.end = old_range.end.saturating_sub(suffix_len);
 674
 675                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 676                (
 677                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
 678                    new_text,
 679                )
 680            })
 681            .collect()
 682    }
 683
 684    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
 685        self.rated_completions.contains(&completion_id)
 686    }
 687
 688    pub fn was_completion_shown(&self, completion_id: InlineCompletionId) -> bool {
 689        self.shown_completions.contains(&completion_id)
 690    }
 691
 692    pub fn completion_shown(&mut self, completion_id: InlineCompletionId) {
 693        self.shown_completions.insert(completion_id);
 694    }
 695
 696    pub fn rate_completion(
 697        &mut self,
 698        completion: &InlineCompletion,
 699        rating: InlineCompletionRating,
 700        feedback: String,
 701        cx: &mut ModelContext<Self>,
 702    ) {
 703        telemetry::event!(
 704            "Inline Completion Rated",
 705            rating,
 706            input_events = completion.input_events,
 707            input_excerpt = completion.input_excerpt,
 708            input_outline = completion.input_outline,
 709            output_excerpt = completion.output_excerpt,
 710            feedback
 711        );
 712        self.client.telemetry().flush_events();
 713        cx.notify();
 714    }
 715
 716    pub fn recent_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
 717        self.recent_completions.iter()
 718    }
 719
 720    pub fn recent_completions_len(&self) -> usize {
 721        self.recent_completions.len()
 722    }
 723
 724    fn report_changes_for_buffer(
 725        &mut self,
 726        buffer: &Model<Buffer>,
 727        cx: &mut ModelContext<Self>,
 728    ) -> BufferSnapshot {
 729        self.register_buffer(buffer, cx);
 730
 731        let registered_buffer = self
 732            .registered_buffers
 733            .get_mut(&buffer.entity_id())
 734            .unwrap();
 735        let new_snapshot = buffer.read(cx).snapshot();
 736
 737        if new_snapshot.version != registered_buffer.snapshot.version {
 738            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 739            self.push_event(Event::BufferChange {
 740                old_snapshot,
 741                new_snapshot: new_snapshot.clone(),
 742                timestamp: Instant::now(),
 743            });
 744        }
 745
 746        new_snapshot
 747    }
 748}
 749
 750fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 751    a.zip(b)
 752        .take_while(|(a, b)| a == b)
 753        .map(|(a, _)| a.len_utf8())
 754        .sum()
 755}
 756
 757fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
 758    let mut input_outline = String::new();
 759
 760    writeln!(
 761        input_outline,
 762        "```{}",
 763        snapshot
 764            .file()
 765            .map_or(Cow::Borrowed("untitled"), |file| file
 766                .path()
 767                .to_string_lossy())
 768    )
 769    .unwrap();
 770
 771    if let Some(outline) = snapshot.outline(None) {
 772        let guess_size = outline.items.len() * 15;
 773        input_outline.reserve(guess_size);
 774        for item in outline.items.iter() {
 775            let spacing = " ".repeat(item.depth);
 776            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
 777        }
 778    }
 779
 780    writeln!(input_outline, "```").unwrap();
 781
 782    input_outline
 783}
 784
 785fn prompt_for_excerpt(
 786    snapshot: &BufferSnapshot,
 787    excerpt_range: &Range<usize>,
 788    offset: usize,
 789) -> String {
 790    let mut prompt_excerpt = String::new();
 791    writeln!(
 792        prompt_excerpt,
 793        "```{}",
 794        snapshot
 795            .file()
 796            .map_or(Cow::Borrowed("untitled"), |file| file
 797                .path()
 798                .to_string_lossy())
 799    )
 800    .unwrap();
 801
 802    if excerpt_range.start == 0 {
 803        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
 804    }
 805
 806    let point_range = excerpt_range.to_point(snapshot);
 807    if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
 808        let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
 809        for chunk in snapshot.text_for_range(extra_context_line_range) {
 810            prompt_excerpt.push_str(chunk);
 811        }
 812    }
 813    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
 814    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
 815        prompt_excerpt.push_str(chunk);
 816    }
 817    prompt_excerpt.push_str(CURSOR_MARKER);
 818    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
 819        prompt_excerpt.push_str(chunk);
 820    }
 821    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 822
 823    if point_range.end.row < snapshot.max_point().row
 824        && !snapshot.is_line_blank(point_range.end.row + 1)
 825    {
 826        let extra_context_line_range = point_range.end
 827            ..Point::new(
 828                point_range.end.row + 1,
 829                snapshot.line_len(point_range.end.row + 1),
 830            );
 831        for chunk in snapshot.text_for_range(extra_context_line_range) {
 832            prompt_excerpt.push_str(chunk);
 833        }
 834    }
 835
 836    write!(prompt_excerpt, "\n```").unwrap();
 837    prompt_excerpt
 838}
 839
 840fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
 841    const CONTEXT_LINES: u32 = 32;
 842
 843    let mut context_lines_before = CONTEXT_LINES;
 844    let mut context_lines_after = CONTEXT_LINES;
 845    if point.row < CONTEXT_LINES {
 846        context_lines_after += CONTEXT_LINES - point.row;
 847    } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
 848        context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
 849    }
 850
 851    let excerpt_start_row = point.row.saturating_sub(context_lines_before);
 852    let excerpt_start = Point::new(excerpt_start_row, 0);
 853    let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
 854    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
 855    excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
 856}
 857
 858struct RegisteredBuffer {
 859    snapshot: BufferSnapshot,
 860    _subscriptions: [gpui::Subscription; 2],
 861}
 862
 863#[derive(Clone)]
 864enum Event {
 865    BufferChange {
 866        old_snapshot: BufferSnapshot,
 867        new_snapshot: BufferSnapshot,
 868        timestamp: Instant,
 869    },
 870}
 871
 872impl Event {
 873    fn to_prompt(&self) -> String {
 874        match self {
 875            Event::BufferChange {
 876                old_snapshot,
 877                new_snapshot,
 878                ..
 879            } => {
 880                let mut prompt = String::new();
 881
 882                let old_path = old_snapshot
 883                    .file()
 884                    .map(|f| f.path().as_ref())
 885                    .unwrap_or(Path::new("untitled"));
 886                let new_path = new_snapshot
 887                    .file()
 888                    .map(|f| f.path().as_ref())
 889                    .unwrap_or(Path::new("untitled"));
 890                if old_path != new_path {
 891                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
 892                }
 893
 894                let diff =
 895                    similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
 896                        .unified_diff()
 897                        .to_string();
 898                if !diff.is_empty() {
 899                    write!(
 900                        prompt,
 901                        "User edited {:?}:\n```diff\n{}\n```",
 902                        new_path, diff
 903                    )
 904                    .unwrap();
 905                }
 906
 907                prompt
 908            }
 909        }
 910    }
 911}
 912
 913#[derive(Debug, Clone)]
 914struct CurrentInlineCompletion {
 915    buffer_id: EntityId,
 916    completion: InlineCompletion,
 917}
 918
 919impl CurrentInlineCompletion {
 920    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
 921        if self.buffer_id != old_completion.buffer_id {
 922            return true;
 923        }
 924
 925        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
 926            return true;
 927        };
 928        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
 929            return false;
 930        };
 931
 932        if old_edits.len() == 1 && new_edits.len() == 1 {
 933            let (old_range, old_text) = &old_edits[0];
 934            let (new_range, new_text) = &new_edits[0];
 935            new_range == old_range && new_text.starts_with(old_text)
 936        } else {
 937            true
 938        }
 939    }
 940}
 941
 942struct PendingCompletion {
 943    id: usize,
 944    _task: Task<Result<()>>,
 945}
 946
 947pub struct ZetaInlineCompletionProvider {
 948    zeta: Model<Zeta>,
 949    pending_completions: ArrayVec<PendingCompletion, 2>,
 950    next_pending_completion_id: usize,
 951    current_completion: Option<CurrentInlineCompletion>,
 952}
 953
 954impl ZetaInlineCompletionProvider {
 955    pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
 956
 957    pub fn new(zeta: Model<Zeta>) -> Self {
 958        Self {
 959            zeta,
 960            pending_completions: ArrayVec::new(),
 961            next_pending_completion_id: 0,
 962            current_completion: None,
 963        }
 964    }
 965}
 966
 967impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
 968    fn name() -> &'static str {
 969        "zeta"
 970    }
 971
 972    fn display_name() -> &'static str {
 973        "Zeta"
 974    }
 975
 976    fn show_completions_in_menu() -> bool {
 977        true
 978    }
 979
 980    fn is_enabled(
 981        &self,
 982        buffer: &Model<Buffer>,
 983        cursor_position: language::Anchor,
 984        cx: &AppContext,
 985    ) -> bool {
 986        let buffer = buffer.read(cx);
 987        let file = buffer.file();
 988        let language = buffer.language_at(cursor_position);
 989        let settings = all_language_settings(file, cx);
 990        settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
 991    }
 992
 993    fn refresh(
 994        &mut self,
 995        buffer: Model<Buffer>,
 996        position: language::Anchor,
 997        debounce: bool,
 998        cx: &mut ModelContext<Self>,
 999    ) {
1000        let pending_completion_id = self.next_pending_completion_id;
1001        self.next_pending_completion_id += 1;
1002
1003        let task = cx.spawn(|this, mut cx| async move {
1004            if debounce {
1005                cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1006            }
1007
1008            let completion_request = this.update(&mut cx, |this, cx| {
1009                this.zeta.update(cx, |zeta, cx| {
1010                    zeta.request_completion(&buffer, position, cx)
1011                })
1012            });
1013
1014            let mut completion = None;
1015            if let Ok(completion_request) = completion_request {
1016                completion = Some(CurrentInlineCompletion {
1017                    buffer_id: buffer.entity_id(),
1018                    completion: completion_request.await?,
1019                });
1020            }
1021
1022            this.update(&mut cx, |this, cx| {
1023                if this.pending_completions[0].id == pending_completion_id {
1024                    this.pending_completions.remove(0);
1025                } else {
1026                    this.pending_completions.clear();
1027                }
1028
1029                if let Some(new_completion) = completion {
1030                    if let Some(old_completion) = this.current_completion.as_ref() {
1031                        let snapshot = buffer.read(cx).snapshot();
1032                        if new_completion.should_replace_completion(&old_completion, &snapshot) {
1033                            this.zeta.update(cx, |zeta, _cx| {
1034                                zeta.completion_shown(new_completion.completion.id)
1035                            });
1036                            this.current_completion = Some(new_completion);
1037                        }
1038                    } else {
1039                        this.zeta.update(cx, |zeta, _cx| {
1040                            zeta.completion_shown(new_completion.completion.id)
1041                        });
1042                        this.current_completion = Some(new_completion);
1043                    }
1044                } else {
1045                    this.current_completion = None;
1046                }
1047
1048                cx.notify();
1049            })
1050        });
1051
1052        // We always maintain at most two pending completions. When we already
1053        // have two, we replace the newest one.
1054        if self.pending_completions.len() <= 1 {
1055            self.pending_completions.push(PendingCompletion {
1056                id: pending_completion_id,
1057                _task: task,
1058            });
1059        } else if self.pending_completions.len() == 2 {
1060            self.pending_completions.pop();
1061            self.pending_completions.push(PendingCompletion {
1062                id: pending_completion_id,
1063                _task: task,
1064            });
1065        }
1066    }
1067
1068    fn cycle(
1069        &mut self,
1070        _buffer: Model<Buffer>,
1071        _cursor_position: language::Anchor,
1072        _direction: inline_completion::Direction,
1073        _cx: &mut ModelContext<Self>,
1074    ) {
1075        // Right now we don't support cycling.
1076    }
1077
1078    fn accept(&mut self, _cx: &mut ModelContext<Self>) {
1079        self.pending_completions.clear();
1080    }
1081
1082    fn discard(&mut self, _cx: &mut ModelContext<Self>) {
1083        self.pending_completions.clear();
1084        self.current_completion.take();
1085    }
1086
1087    fn suggest(
1088        &mut self,
1089        buffer: &Model<Buffer>,
1090        cursor_position: language::Anchor,
1091        cx: &mut ModelContext<Self>,
1092    ) -> Option<inline_completion::InlineCompletion> {
1093        let CurrentInlineCompletion {
1094            buffer_id,
1095            completion,
1096            ..
1097        } = self.current_completion.as_mut()?;
1098
1099        // Invalidate previous completion if it was generated for a different buffer.
1100        if *buffer_id != buffer.entity_id() {
1101            self.current_completion.take();
1102            return None;
1103        }
1104
1105        let buffer = buffer.read(cx);
1106        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1107            self.current_completion.take();
1108            return None;
1109        };
1110
1111        let cursor_row = cursor_position.to_point(buffer).row;
1112        let (closest_edit_ix, (closest_edit_range, _)) =
1113            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1114                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1115                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1116                cmp::min(distance_from_start, distance_from_end)
1117            })?;
1118
1119        let mut edit_start_ix = closest_edit_ix;
1120        for (range, _) in edits[..edit_start_ix].iter().rev() {
1121            let distance_from_closest_edit =
1122                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1123            if distance_from_closest_edit <= 1 {
1124                edit_start_ix -= 1;
1125            } else {
1126                break;
1127            }
1128        }
1129
1130        let mut edit_end_ix = closest_edit_ix + 1;
1131        for (range, _) in &edits[edit_end_ix..] {
1132            let distance_from_closest_edit =
1133                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1134            if distance_from_closest_edit <= 1 {
1135                edit_end_ix += 1;
1136            } else {
1137                break;
1138            }
1139        }
1140
1141        Some(inline_completion::InlineCompletion {
1142            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1143        })
1144    }
1145}
1146
// Unit tests for completion interpolation and a full request round-trip
// against fake HTTP and RPC backends.
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    // Checks how a completion's edits are re-mapped onto later buffer states.
    // The buffer states come from edits and an undo applied after the
    // completion was created.
    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Suggests replacing "rem" (2..5) with "REM" and deleting "um" (9..11).
        let completion = InlineCompletion {
            edits: to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into(),
            path: Path::new("").into(),
            snapshot: buffer.read(cx).snapshot(),
            id: InlineCompletionId::new(),
            excerpt_range: 0..0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Unchanged buffer: the edits come back as created.
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Deleting "rem" turns the first edit into an insertion and shifts the second.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
        );

        // Undo restores the original mapping.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Typing "R" (a prefix of the suggestion) leaves only "EM" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
        );

        // Typing "E" leaves only "M".
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Typing "M" completes the first edit; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".to_string())]
        );

        // Removing the typed "M" brings the insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Performing the suggested deletion ("um") by hand drops that edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string())]
        );

        // Deleting " i" diverges from the completion, so interpolation fails.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
    }

    // Requests a completion with the cursor on the empty last row of "lorem\n".
    // The fake HTTP response rewrites the editable region to "lorem\nipsum".
    // Applying the returned edits must yield exactly "lorem\nipsum".
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request receives the canned prediction above.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new_model(|cx| Zeta::new(client, cx));
        let buffer = cx.new_model(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task =
            zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));

        // The request first asks the RPC server for an LLM token; answer it.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    // Converts offset-based edits into anchor-based ones. Starts use
    // `anchor_after` and ends use `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    // Resolves anchor-based edits back to offsets in the buffer's current state.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    // Runs at test-binary load via `ctor`. It enables `env_logger` only when
    // `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}